diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml
index c57da7cb5f..54702c914a 100644
--- a/.github/actions/setup-web/action.yml
+++ b/.github/actions/setup-web/action.yml
@@ -1,33 +1,13 @@
 name: Setup Web Environment
-description: Setup pnpm, Node.js, and install web dependencies.
-
-inputs:
-  node-version:
-    description: Node.js version to use
-    required: false
-    default: "22"
-  install-dependencies:
-    description: Whether to install web dependencies after setting up Node.js
-    required: false
-    default: "true"
 
 runs:
   using: composite
   steps:
-    - name: Install pnpm
-      uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
+    - name: Setup Vite+
+      uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
       with:
-        package_json_file: web/package.json
-        run_install: false
-
-    - name: Setup Node.js
-      uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
-      with:
-        node-version: ${{ inputs.node-version }}
-        cache: pnpm
-        cache-dependency-path: ./web/pnpm-lock.yaml
-
-    - name: Install dependencies
-      if: ${{ inputs.install-dependencies == 'true' }}
-      shell: bash
-      run: pnpm --dir web install --frozen-lockfile
+        node-version-file: "./web/.nvmrc"
+        cache: true
+        run-install: |
+          - cwd: ./web
+            args: ['--frozen-lockfile']
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 2af3b130ad..80f892589d 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -102,13 +102,11 @@ jobs:
       - name: Setup web environment
         if: steps.web-changes.outputs.any_changed == 'true'
         uses: ./.github/actions/setup-web
-        with:
-          node-version: "24"
 
       - name: ESLint autofix
         if: steps.web-changes.outputs.any_changed == 'true'
         run: |
           cd web
-          pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
+          vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
 
       - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml
index ef2e3c7bb4..fd104e9496 100644
--- a/.github/workflows/main-ci.yml
+++ b/.github/workflows/main-ci.yml
@@ -62,6 +62,9 @@ jobs:
     needs: check-changes
     if: needs.check-changes.outputs.web-changed == 'true'
     uses: ./.github/workflows/web-tests.yml
+    with:
+      base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
+      head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 
   style-check:
     name: Style Check
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index b284694530..868bacc6e5 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -88,7 +88,7 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
         run: |
-          pnpm run lint:ci
+          vp run lint:ci
           # pnpm run lint:report
           # continue-on-error: true
 
@@ -102,17 +102,17 @@
       - name: Web tsslint
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run lint:tss
+        run: vp run lint:tss
 
       - name: Web type check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run type-check
+        run: vp run type-check
 
       - name: Web dead code check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run knip
+        run: vp run knip
 
   superlinter:
     name: SuperLinter
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index ff07313ebe..62724c84e5 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -50,8 +50,6 @@ jobs:
       - name: Setup web environment
         uses: ./.github/actions/setup-web
-        with:
-          install-dependencies: "false"
 
       - name: Detect changed files and generate diff
         id: detect_changes
 
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 33e9170b02..fd2b941ce3 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -2,6 +2,13 @@ name: Web Tests
 
 on:
   workflow_call:
+    inputs:
+      base_sha:
+        required: false
+        type: string
+      head_sha:
+        required: false
+        type: string
 
 permissions:
   contents: read
@@ -14,6 +21,8 @@ jobs:
   test:
     name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
     runs-on: ubuntu-latest
+    env:
+      VITEST_COVERAGE_SCOPE: app-components
    strategy:
       fail-fast: false
       matrix:
@@ -34,7 +43,7 @@
         uses: ./.github/actions/setup-web
 
       - name: Run tests
-        run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
+        run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
 
       - name: Upload blob report
         if: ${{ !cancelled() }}
@@ -50,6 +59,8 @@
     if: ${{ !cancelled() }}
     needs: [test]
     runs-on: ubuntu-latest
+    env:
+      VITEST_COVERAGE_SCOPE: app-components
     defaults:
       run:
         shell: bash
@@ -59,6 +70,7 @@
       - name: Checkout code
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
         with:
+          fetch-depth: 0
           persist-credentials: false
 
       - name: Setup web environment
@@ -72,7 +84,13 @@
           merge-multiple: true
 
       - name: Merge reports
-        run: pnpm vitest --merge-reports --coverage --silent=passed-only
+        run: vp test --merge-reports --reporter=json --reporter=agent --coverage
+
+      - name: Check app/components diff coverage
+        env:
+          BASE_SHA: ${{ inputs.base_sha }}
+          HEAD_SHA: ${{ inputs.head_sha }}
+        run: node ./scripts/check-components-diff-coverage.mjs
 
       - name: Coverage Summary
         if: always()
@@ -429,4 +447,4 @@
       - name: Web build check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run build
+        run: vp run build
diff --git a/api/.env.example b/api/.env.example
index ab8b6c5287..8fbe2e4643 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -188,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
 
 # Weaviate configuration
 WEAVIATE_ENDPOINT=http://localhost:8080
 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
-WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
 WEAVIATE_TOKENIZATION=word
diff --git a/api/.importlinter b/api/.importlinter
index 5c0a6e1288..8dffc3506b 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -43,7 +43,6 @@ forbidden_modules =
     extensions.ext_redis
 allow_indirect_imports = True
 ignore_imports =
-    dify_graph.nodes.agent.agent_node -> extensions.ext_database
     dify_graph.nodes.llm.node -> extensions.ext_database
     dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
     dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@@ -90,9 +89,6 @@ forbidden_modules =
    core.trigger
    core.variables
 ignore_imports =
-    dify_graph.nodes.agent.agent_node -> core.model_manager
-    dify_graph.nodes.agent.agent_node -> core.provider_manager
-    dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
     dify_graph.nodes.llm.llm_utils -> core.model_manager
     dify_graph.nodes.llm.protocols -> core.model_manager
     dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -100,8 +96,6 @@ ignore_imports =
     dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
     dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
     dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
-    dify_graph.nodes.agent.agent_node -> core.agent.entities
-    dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
     dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
     dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
     dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
@@ -110,12 +104,10 @@ ignore_imports =
     dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
     dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
     dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
-    dify_graph.nodes.agent.agent_node -> models.model
     dify_graph.nodes.llm.node -> core.helper.code_executor
     dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
     dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
     dify_graph.nodes.llm.node -> core.model_manager
-    dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
     dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
     dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
     dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
@@ -126,15 +118,11 @@ ignore_imports =
     dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
     dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
     dify_graph.nodes.llm.node -> models.dataset
-    dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
     dify_graph.nodes.llm.file_saver -> core.tools.signature
     dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
     dify_graph.nodes.tool.tool_node -> core.tools.errors
-    dify_graph.nodes.agent.agent_node -> extensions.ext_database
     dify_graph.nodes.llm.node -> extensions.ext_database
-    dify_graph.nodes.agent.agent_node -> models
     dify_graph.nodes.llm.node -> models.model
-    dify_graph.nodes.agent.agent_node -> services
     dify_graph.nodes.tool.tool_node -> services
     dify_graph.model_runtime.model_providers.__base.ai_model -> configs
     dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py
index 6f4fccaa7f..2d1216c0d1 100644
--- a/api/configs/middleware/vdb/weaviate_config.py
+++ b/api/configs/middleware/vdb/weaviate_config.py
@@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings):
         default=None,
     )
 
-    WEAVIATE_GRPC_ENABLED: bool = Field(
-        description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
-        default=True,
-    )
-
     WEAVIATE_GRPC_ENDPOINT: str | None = Field(
         description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
         default=None,
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
index dd982b6d7b..2025048e09 100644
--- a/api/controllers/console/app/mcp_server.py
+++ b/api/controllers/console/app/mcp_server.py
@@ -1,5 +1,4 @@
 import json
-from enum import StrEnum
 
 from flask_restx import Resource, marshal_with
 from pydantic import BaseModel, Field
@@ -11,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
 from extensions.ext_database import db
 from fields.app_fields import app_server_fields
 from libs.login import current_account_with_tenant, login_required
+from models.enums import AppMCPServerStatus
 from models.model import AppMCPServer
 
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -19,11 +19,6 @@
 app_server_model = console_ns.model("AppServer", app_server_fields)
 
 
-class AppMCPServerStatus(StrEnum):
-    ACTIVE = "active"
-    INACTIVE = "inactive"
-
-
 class MCPServerCreatePayload(BaseModel):
     description: str | None = Field(default=None, description="Server description")
     parameters: dict = Field(..., description="Server parameters configuration")
@@ -117,9 +112,10 @@ class AppMCPServerController(Resource):
         server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
 
         if payload.status:
-            if payload.status not in [status.value for status in AppMCPServerStatus]:
+            try:
+                server.status = AppMCPServerStatus(payload.status)
+            except ValueError:
                 raise ValueError("Invalid status")
-            server.status = payload.status
 
         db.session.commit()
         return server
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 708df62642..0d8960c9bd 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now
 from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
 from libs.login import current_account_with_tenant, login_required
 from models import AccountIntegrate, InvitationCode
+from models.account import AccountStatus, InvitationCodeStatus
 from services.account_service import AccountService
 from services.billing_service import BillingService
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -215,7 +216,7 @@ class AccountInitApi(Resource):
             db.session.query(InvitationCode)
             .where(
                 InvitationCode.code == args.invitation_code,
-                InvitationCode.status == "unused",
+                InvitationCode.status == InvitationCodeStatus.UNUSED,
             )
             .first()
         )
@@ -223,7 +224,7 @@ class AccountInitApi(Resource):
             if not invitation_code:
                 raise InvalidInvitationCodeError()
 
-            invitation_code.status = "used"
+            invitation_code.status = InvitationCodeStatus.USED
             invitation_code.used_at = naive_utc_now()
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
@@ -231,7 +232,7 @@ class AccountInitApi(Resource):
         account.interface_language = args.interface_language
         account.timezone = args.timezone
         account.interface_theme = "light"
-        account.status = "active"
+        account.status = AccountStatus.ACTIVE
         account.initialized_at = naive_utc_now()
 
         db.session.commit()
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index 2f06f72f29..ee537367c7 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -5,6 +5,7 @@ from typing import Any, Literal
 from flask import request, send_file
 from flask_restx import Resource
 from pydantic import BaseModel, Field
+from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
@@ -169,6 +170,20 @@ register_enum_models(
 )
 
 
+def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
+    """
+    Read the uploaded file and validate its actual size before delegating to the plugin service.
+
+    FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
+    content exists, so the controllers validate against the loaded bytes instead.
+    """
+    content = file.read()
+    if len(content) > max_size:
+        raise ValueError("File size exceeds the maximum allowed size")
+
+    return content
+
+
 @console_ns.route("/workspaces/current/plugin/debugging-key")
 class PluginDebuggingKeyApi(Resource):
     @setup_required
@@ -284,12 +299,7 @@ class PluginUploadFromPkgApi(Resource):
         _, tenant_id = current_account_with_tenant()
 
         file = request.files["pkg"]
-
-        # check file size
-        if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
-            raise ValueError("File size exceeds the maximum allowed size")
-
-        content = file.read()
+        content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
         try:
             response = PluginService.upload_pkg(tenant_id, content)
         except PluginDaemonClientSideError as e:
@@ -328,12 +338,7 @@ class PluginUploadFromBundleApi(Resource):
         _, tenant_id = current_account_with_tenant()
 
         file = request.files["bundle"]
-
-        # check file size
-        if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
-            raise ValueError("File size exceeds the maximum allowed size")
-
-        content = file.read()
+        content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
         try:
             response = PluginService.upload_bundle(tenant_id, content)
         except PluginDaemonClientSideError as e:
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index edf3ac393c..766d95b3dd 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
 
 def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
     def decorator(view_func: Callable[P, R]):
+        @wraps(view_func)
         def decorated_view(*args: P.args, **kwargs: P.kwargs):
             try:
                 data = request.get_json()
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
index 2bc6640807..9ddaaa315b 100644
--- a/api/controllers/mcp/mcp.py
+++ b/api/controllers/mcp/mcp.py
@@ -6,13 +6,13 @@
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy.orm import Session
 
 from controllers.common.schema import register_schema_model
-from controllers.console.app.mcp_server import AppMCPServerStatus
 from controllers.mcp import mcp_ns
 from core.mcp import types as mcp_types
 from core.mcp.server.streamable_http import handle_mcp_request
 from dify_graph.variables.input_entities import VariableEntity
 from extensions.ext_database import db
 from libs import helper
+from models.enums import AppMCPServerStatus
 from models.model import App, AppMCPServer, AppMode, EndUser
 
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index c6ecd5509b..9271ed10bd 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -6,6 +6,7 @@ from typing import Any
 
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit
+from core.agent.errors import AgentMaxIterationError
 from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@@ -22,7 +23,6 @@ from dify_graph.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from dify_graph.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/agent/errors.py b/api/core/agent/errors.py
new file mode 100644
index 0000000000..ed504d500a
--- /dev/null
+++ b/api/core/agent/errors.py
@@ -0,0 +1,9 @@
+class AgentMaxIterationError(Exception):
+    """Raised when an agent runner exceeds the configured max iteration count."""
+
+    def __init__(self, max_iteration: int):
+        self.max_iteration = max_iteration
+        super().__init__(
+            f"Agent exceeded the maximum iteration limit of {max_iteration}. "
+            f"The agent was unable to complete the task within the allowed number of iterations."
+        )
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 3271fe319b..5e13a13b21 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -5,6 +5,7 @@ from copy import deepcopy
 from typing import Any, Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
+from core.agent.errors import AgentMaxIterationError
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
@@ -25,7 +26,6 @@ from dify_graph.model_runtime.entities import (
     UserPromptMessage,
 )
 from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from dify_graph.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index 02ec96f209..5c9bc43992 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            if isinstance(sub_stream_response, ErrorStreamResponse):
+            elif isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py
index e35e9d9408..0c146c388f 100644
--- a/api/core/app/apps/agent_chat/generate_response_converter.py
+++ b/api/core/app/apps/agent_chat/generate_response_converter.py
@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            if isinstance(sub_stream_response, ErrorStreamResponse):
+            elif isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py
index 3aa1161fd8..f23ee7f89f 100644
--- a/api/core/app/apps/chat/generate_response_converter.py
+++ b/api/core/app/apps/chat/generate_response_converter.py
@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
-            if isinstance(sub_stream_response, ErrorStreamResponse):
+            elif isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 7ef6ff7cc2..8986164fe7 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -3,7 +3,10 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
+from pydantic import ValidationError
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
 from core.app.entities.queue_entities import (
     AppQueueEvent,
@@ -30,8 +33,10 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.node_resolution import resolve_workflow_node_class
 from core.workflow.workflow_entry import WorkflowEntry
 from dify_graph.entities import GraphInitParams
+from dify_graph.entities.graph_config import NodeConfigDictAdapter
 from dify_graph.entities.pause_reason import HumanInputRequired
 from dify_graph.graph import Graph
 from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -62,8 +67,6 @@ from dify_graph.graph_events import (
     NodeRunSucceededEvent,
 )
 from dify_graph.graph_events.graph import GraphRunAbortedEvent
-from dify_graph.nodes import NodeType
-from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@@ -303,10 +306,12 @@ class WorkflowBasedAppRunner:
         if not target_node_config:
             raise ValueError(f"{node_type_label} node id not found in workflow graph")
 
+        target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
+
         # Get node class
-        node_type = NodeType(target_node_config.get("data", {}).get("type"))
-        node_version = target_node_config.get("data", {}).get("version", "1")
-        node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
+        node_type = target_node_config["data"].type
+        node_version = str(target_node_config["data"].version)
+        node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
 
         # Use the variable pool from graph_runtime_state instead of creating a new one
         variable_pool = graph_runtime_state.variable_pool
@@ -334,6 +339,18 @@ class WorkflowBasedAppRunner:
 
         return graph, variable_pool
 
+    @staticmethod
+    def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
+        raw_agent_strategy = event.extras.get("agent_strategy")
+        if raw_agent_strategy is None:
+            return None
+
+        try:
+            return AgentStrategyInfo.model_validate(raw_agent_strategy)
+        except ValidationError:
+            logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
+            return None
+
     def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
         """
         Handle event
@@ -419,7 +436,7 @@ class WorkflowBasedAppRunner:
                     start_at=event.start_at,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
-                    agent_strategy=event.agent_strategy,
+                    agent_strategy=self._build_agent_strategy_info(event),
                     provider_type=event.provider_type,
                     provider_id=event.provider_id,
                 )
diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py
index e69de29bb2..8e41acee32 100644
--- a/api/core/app/entities/__init__.py
+++ b/api/core/app/entities/__init__.py
@@ -0,0 +1,3 @@
+from .agent_strategy import AgentStrategyInfo
+
+__all__ = ["AgentStrategyInfo"]
diff --git a/api/core/app/entities/agent_strategy.py b/api/core/app/entities/agent_strategy.py
new file mode 100644
index 0000000000..b063a12f4f
--- /dev/null
+++ b/api/core/app/entities/agent_strategy.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class AgentStrategyInfo(BaseModel):
+    name: str
+    icon: str | None = None
+
+    model_config = ConfigDict(extra="forbid")
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index d42df0d1bf..2d1508f0cb 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -5,8 +5,8 @@ from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from dify_graph.entities import AgentNodeStrategyInit
 from dify_graph.entities.pause_reason import PauseReason
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
 from dify_graph.enums import WorkflowNodeExecutionMetadataKey
@@ -314,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
     in_iteration_id: str | None = None
     in_loop_id: str | None = None
     start_at: datetime
-    agent_strategy: AgentNodeStrategyInit | None = None
+    agent_strategy: AgentStrategyInfo | None = None
 
     # FIXME(-LAN-): only for ToolNode, need to refactor
     provider_type: str  # should be a core.tools.entities.tool_entities.ToolProviderType
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index b58dae0ff2..46a8ab52f2 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -4,8 +4,8 @@ from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from dify_graph.entities import AgentNodeStrategyInit
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
 from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
         extras: dict[str, object] = Field(default_factory=dict)
         iteration_id: str | None = None
         loop_id: str | None = None
-        agent_strategy: AgentNodeStrategyInit | None = None
+        agent_strategy: AgentStrategyInfo | None = None
 
     event: StreamEvent = StreamEvent.NODE_STARTED
     workflow_run_id: str
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index d0279349ca..b054409681 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -12,6 +12,7 @@
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
+from models.enums import CreatorUserRole
 
 _logger = logging.getLogger(__name__)
@@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler:
             source="app",
             source_app_id=self._app_id,
             created_by_role=(
-                "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
+                CreatorUserRole.ACCOUNT
+                if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                else CreatorUserRole.END_USER
             ),
             created_by=self._user_id,
         )
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 6a09dbff35..c8848336d9 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -193,7 +193,8 @@ class LLMGenerator:
             error_step = "generate rule config"
         except Exception as e:
             logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
-            rule_config["error"] = str(e)
+            error = str(e)
+            error_step = "generate rule config"
 
         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
 
@@ -279,7 +280,8 @@ class LLMGenerator:
 
         except Exception as e:
             logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
-            rule_config["error"] = str(e)
+            error = str(e)
+            error_step = "handle unexpected exception"
 
         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
 
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 33782e7949..9ac753240b 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -628,10 +628,10 @@ class TraceTask:
         if not message_data:
             return {}
         conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
-        conversation_mode = db.session.scalars(conversation_mode_stmt).all()
-        if not conversation_mode or len(conversation_mode) == 0:
+        conversation_modes = db.session.scalars(conversation_mode_stmt).all()
+        if not conversation_modes or len(conversation_modes) == 0:
             return {}
-        conversation_mode = conversation_mode[0]
+        conversation_mode = conversation_modes[0]
 
         created_at = message_data.created_at
         inputs = message_data.message
diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py
index bfa662b9f6..ce5813a294 100644
--- a/api/core/plugin/entities/parameters.py
+++ b/api/core/plugin/entities/parameters.py
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
     except ValueError:
         raise
     except Exception:
-        raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
+        raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
 
 
 def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index f82c3a846b..c29a463bb6 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -627,7 +627,7 @@ class ProviderManager:
                     tenant_id=tenant_id,
                     # TODO: Use provider name with prefix after the data migration.
                     provider_name=ModelProviderID(provider_name).provider_name,
-                    provider_type=ProviderType.SYSTEM.value,
+                    provider_type=ProviderType.SYSTEM,
                     quota_type=quota.quota_type,
                     quota_limit=0,  # type: ignore
                     quota_used=0,
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 8243170c62..fcd3cceb59 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -83,6 +83,7 @@ from models.dataset import (
 )
 from models.dataset import Document as DatasetDocument
 from models.dataset import Document as DocumentModel
+from models.enums import CreatorUserRole
 from services.external_knowledge_service import ExternalDatasetService
 from services.feature_service import FeatureService
 
@@ -1009,7 +1010,7 @@ class DatasetRetrieval:
                 content=json.dumps(contents),
                 source="app",
                 source_app_id=app_id,
-                created_by_role=user_from,
+                created_by_role=CreatorUserRole(user_from),
                 created_by=user_id,
             )
             dataset_queries.append(dataset_query)
diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
index 770df8b050..55e96515ac 100644
--- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
 
             # No sequence number generation needed anymore
 
-            db_model.type = domain_model.workflow_type
+            from models.workflow import WorkflowType as ModelWorkflowType
+
+            db_model.type = ModelWorkflowType(domain_model.workflow_type.value)
             db_model.version = domain_model.workflow_version
             db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
             db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py
index 50105bd707..20cdb3e57f 100644
--- a/api/core/tools/builtin_tool/provider.py
+++ b/api/core/tools/builtin_tool/provider.py
@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         return self.get_credentials_schema_by_type(CredentialType.API_KEY)
 
-    def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
+    def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
         """
         returns the credentials schema of the provider
 
-        :param credential_type: the type of the credential
-        :return: the credentials schema of the provider
+        :param credential_type: the type of the credential, as CredentialType or str; str values
+            are normalized via CredentialType.of and may raise ValueError for invalid values.
+        :return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
+            empty list for CredentialType.UNAUTHORIZED or missing schemas.
+
+        Reads from self.entity.oauth_schema and self.entity.credentials_schema.
+        Raises ValueError for invalid credential types.
         """
-        if credential_type == CredentialType.OAUTH2.value:
+        if isinstance(credential_type, str):
+            credential_type = CredentialType.of(credential_type)
+        if credential_type == CredentialType.OAUTH2:
             return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
         if credential_type == CredentialType.API_KEY:
             return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
+        if credential_type == CredentialType.UNAUTHORIZED:
+            return []
         raise ValueError(f"Invalid credential type: {credential_type}")
 
     def get_oauth_client_schema(self) -> list[ProviderConfig]:
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index f6eccc734b..210f488afc 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -137,6 +137,7 @@ class ToolFileManager:
 
             session.add(tool_file)
             session.commit()
+            session.refresh(tool_file)
 
         return tool_file
 
diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py
index 9b7b3de614..442a2434d5 100644
--- a/api/core/trigger/debug/event_selectors.py
+++ b/api/core/trigger/debug/event_selectors.py
@@ -19,6 +19,7 @@ from core.trigger.debug.events import (
     build_plugin_pool_key,
     build_webhook_pool_key,
 )
+from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import NodeType
 from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
 from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
     app_id: str
     user_id: str
     tenant_id: str
-    node_config: Mapping[str, Any]
+    node_config: NodeConfigDict
     node_id: str
 
-    def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
+    def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
         self.tenant_id = tenant_id
         self.user_id = user_id
         self.app_id = app_id
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
     def poll(self) -> TriggerDebugEvent | None:
         from services.trigger.trigger_service import TriggerService
 
-        plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
+        plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
         provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
         pool_key: str = build_plugin_pool_key(
             name=plugin_trigger_data.event_name,
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 8c6b1dedee..bc4e0eda71 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -1,5 +1,5 @@
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, cast, final
+from collections.abc import Callable, Mapping
+from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -22,7 +22,15 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.summary_index.summary_index import SummaryIndex
 from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from core.tools.tool_file_manager import ToolFileManager
-from dify_graph.entities.graph_config import NodeConfigDict
+from core.workflow.node_resolution import resolve_workflow_node_class
+from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
+from core.workflow.nodes.agent.plugin_strategy_adapter import (
+    PluginAgentStrategyPresentationProvider,
+    PluginAgentStrategyResolver,
+)
+from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
+from dify_graph.entities.base_node_data import BaseNodeData
+from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
 from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
 from dify_graph.enums import NodeType, SystemVariableKey
 from dify_graph.file.file_manager import file_manager
@@ -31,26 +39,18 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from dify_graph.nodes.base.node import Node
-from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
+from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
 from dify_graph.nodes.code.entities import CodeLanguage
 from dify_graph.nodes.code.limits import CodeNodeLimits
-from dify_graph.nodes.datasource import DatasourceNode
-from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
-from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
-from dify_graph.nodes.human_input.human_input_node import HumanInputNode
-from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
-from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
-from dify_graph.nodes.llm.entities import ModelConfig
+from dify_graph.nodes.document_extractor import UnstructuredApiConfig
+from dify_graph.nodes.http_request import build_http_request_config
+from dify_graph.nodes.llm.entities import LLMNodeData
 from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
-from dify_graph.nodes.llm.node import LLMNode
-from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
 from dify_graph.nodes.template_transform.template_renderer import (
     CodeExecutorJinja2TemplateRenderer,
 )
-from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
-from dify_graph.nodes.tool.tool_node import ToolNode
 from dify_graph.variables.segments import StringSegment
 from extensions.ext_database import db
 from models.model import Conversation
@@ -60,6 +60,9 @@ if TYPE_CHECKING:
     from dify_graph.runtime import GraphRuntimeState
 
 
+LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
+
+
 def fetch_memory(
     *,
     conversation_id: str | None,
@@ -100,10 +103,7 @@ class DefaultWorkflowCodeExecutor:
 @final
 class DifyNodeFactory(NodeFactory):
     """
-    Default implementation of NodeFactory that uses the traditional node mapping.
-
-    This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
-    and instantiating the appropriate node class.
+    Default implementation of NodeFactory that resolves node classes from the live registry.
     """
 
     def __init__(
@@ -146,6 +146,10 @@ class DifyNodeFactory(NodeFactory):
         )
 
         self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
+        self._agent_strategy_resolver = PluginAgentStrategyResolver()
+        self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
+        self._agent_runtime_support = AgentRuntimeSupport()
+        self._agent_message_transformer = AgentMessageTransformer()
 
     @staticmethod
     def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
@@ -157,178 +161,125 @@ class DifyNodeFactory(NodeFactory):
         return DifyRunContext.model_validate(raw_ctx)
 
     @override
-    def create_node(self, node_config: NodeConfigDict) -> Node:
+    def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
         """
         Create a Node instance from node configuration data using the traditional mapping.
 
         :param node_config: node configuration dictionary containing type and other data
         :return: initialized Node instance
-        :raises ValueError: if node type is unknown or configuration is invalid
+        :raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
+            (including pydantic ValidationError, which subclasses ValueError),
+            if node type is unknown, or if no implementation exists for the resolved version
         """
-        # Get node_id from config
-        node_id = node_config["id"]
-
-        # Get node type from config
-        node_data = node_config["data"]
-        try:
-            node_type = NodeType(node_data["type"])
-        except ValueError:
-            raise ValueError(f"Unknown node type: {node_data['type']}")
-
-        # Get node class
-        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
-        if not node_mapping:
-            raise ValueError(f"No class mapping found for node type: {node_type}")
-
-        latest_node_class = node_mapping.get(LATEST_VERSION)
-        node_version = str(node_data.get("version", "1"))
-        matched_node_class = node_mapping.get(node_version)
-        node_class = matched_node_class or latest_node_class
-        if not node_class:
-            raise ValueError(f"No latest version class found for node type: {node_type}")
-
-        # Create node instance
-        if node_type == NodeType.CODE:
-            return CodeNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                code_executor=self._code_executor,
-                code_limits=self._code_limits,
-            )
-
-        if node_type == NodeType.TEMPLATE_TRANSFORM:
-            return TemplateTransformNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                template_renderer=self._template_renderer,
-                max_output_length=self._template_transform_max_output_length,
-            )
-
-        if node_type == NodeType.HTTP_REQUEST:
-            return HttpRequestNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                http_request_config=self._http_request_config,
-                http_client=self._http_request_http_client,
-                tool_file_manager_factory=self._http_request_tool_file_manager_factory,
-                file_manager=self._http_request_file_manager,
-            )
-
-        if node_type == NodeType.HUMAN_INPUT:
-            return HumanInputNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
-            )
-
-        if node_type == NodeType.KNOWLEDGE_INDEX:
-            return KnowledgeIndexNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                index_processor=IndexProcessor(),
-                summary_index_service=SummaryIndex(),
-            )
-
-        if node_type == NodeType.LLM:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
-            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
-            return LLMNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                credentials_provider=self._llm_credentials_provider,
-                model_factory=self._llm_model_factory,
-                model_instance=model_instance,
-                memory=memory,
-                http_client=self._http_request_http_client,
-            )
-
-        if node_type == NodeType.DATASOURCE:
-            return DatasourceNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                datasource_manager=DatasourceManager,
-            )
-
-        if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
-            return KnowledgeRetrievalNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                rag_retrieval=self._rag_retrieval,
-            )
-
-        if node_type == NodeType.DOCUMENT_EXTRACTOR:
-            return DocumentExtractorNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                unstructured_api_config=self._document_extractor_unstructured_api_config,
-                http_client=self._http_request_http_client,
-            )
-
-        if node_type == NodeType.QUESTION_CLASSIFIER:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
-            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
-            return QuestionClassifierNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                credentials_provider=self._llm_credentials_provider,
-                model_factory=self._llm_model_factory,
-                model_instance=model_instance,
-                memory=memory,
-                http_client=self._http_request_http_client,
-            )
-
-        if node_type == NodeType.PARAMETER_EXTRACTOR:
-            model_instance = self._build_model_instance_for_llm_node(node_data)
-            memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
-            return ParameterExtractorNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                credentials_provider=self._llm_credentials_provider,
-                model_factory=self._llm_model_factory,
-                model_instance=model_instance,
-                memory=memory,
-            )
-
-        if node_type == NodeType.TOOL:
-            return ToolNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
-            )
-
+        typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+        node_id = typed_node_config["id"]
+        node_data = typed_node_config["data"]
+        node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
+        node_type = node_data.type
+        node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
+            NodeType.CODE: lambda: {
+                "code_executor": self._code_executor,
+                "code_limits": self._code_limits,
+            },
+            NodeType.TEMPLATE_TRANSFORM: lambda: {
+                "template_renderer": self._template_renderer,
+                "max_output_length": self._template_transform_max_output_length,
+            },
+            NodeType.HTTP_REQUEST: lambda: {
+                "http_request_config": self._http_request_config,
+                "http_client": self._http_request_http_client,
+                "tool_file_manager_factory": self._http_request_tool_file_manager_factory,
+                "file_manager": self._http_request_file_manager,
+            },
+            NodeType.HUMAN_INPUT: lambda: {
+                "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
+            },
+            NodeType.KNOWLEDGE_INDEX: lambda: {
+                "index_processor": IndexProcessor(),
+                "summary_index_service": SummaryIndex(),
+            },
+            NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
+                node_class=node_class,
+                node_data=node_data,
+                include_http_client=True,
+            ),
+            NodeType.DATASOURCE: lambda: {
+                "datasource_manager": DatasourceManager,
+            },
+            NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
+                "rag_retrieval": self._rag_retrieval,
+            },
+            NodeType.DOCUMENT_EXTRACTOR: lambda: {
+                "unstructured_api_config": self._document_extractor_unstructured_api_config,
+                "http_client": self._http_request_http_client,
+            },
+            NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
+                node_class=node_class,
+                node_data=node_data,
+                include_http_client=True,
+            ),
+            NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
+                node_class=node_class,
+                node_data=node_data,
+                include_http_client=False,
+            ),
+            NodeType.TOOL: lambda: {
+                "tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
+            },
+            NodeType.AGENT: lambda: {
+                "strategy_resolver": self._agent_strategy_resolver,
+                "presentation_provider": self._agent_strategy_presentation_provider,
+                "runtime_support": self._agent_runtime_support,
+                "message_transformer": self._agent_message_transformer,
+            },
+        }
+        node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
         return node_class(
             id=node_id,
-            config=node_config,
+            config=typed_node_config,
             graph_init_params=self.graph_init_params,
             graph_runtime_state=self.graph_runtime_state,
+            **node_init_kwargs,
         )
 
-    def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
-        node_data_model = ModelConfig.model_validate(node_data["model"])
+    @staticmethod
+    def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
+        """
+        Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
+        """
+        return node_class.validate_node_data(node_data)
+
+    @staticmethod
+    def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+        return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
+
+    def _build_llm_compatible_node_init_kwargs(
+        self,
+        *,
+        node_class: type[Node],
+        node_data: BaseNodeData,
+        include_http_client: bool,
+    ) -> dict[str, object]:
+        validated_node_data = cast(
+            LLMCompatibleNodeData,
+            self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
+        )
+        model_instance = self._build_model_instance_for_llm_node(validated_node_data)
+        node_init_kwargs: dict[str, object] = {
+            "credentials_provider": self._llm_credentials_provider,
+            "model_factory": self._llm_model_factory,
+            "model_instance": model_instance,
+            "memory": self._build_memory_for_llm_node(
+                node_data=validated_node_data,
+                model_instance=model_instance,
+            ),
+        }
+        if include_http_client:
+            node_init_kwargs["http_client"] = self._http_request_http_client
+        return node_init_kwargs
+
+    def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
+        node_data_model = node_data.model
         if not node_data_model.mode:
             raise LLMModeRequiredError("LLM mode is required.")
@@ -364,14 +315,12 @@ class DifyNodeFactory(NodeFactory):
     def _build_memory_for_llm_node(
         self,
         *,
-        node_data: Mapping[str, Any],
+        node_data: LLMCompatibleNodeData,
         model_instance: ModelInstance,
     ) -> PromptMessageMemory | None:
-        raw_memory_config = node_data.get("memory")
-        if raw_memory_config is None:
+        if node_data.memory is None:
             return None
 
-        node_memory = MemoryConfig.model_validate(raw_memory_config)
         conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID]
         )
@@ -381,6 +330,6 @@ class DifyNodeFactory(NodeFactory):
         return fetch_memory(
             conversation_id=conversation_id,
             app_id=self._dify_context.app_id,
-            node_data_memory=node_memory,
+            node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
diff --git a/api/core/workflow/node_resolution.py b/api/core/workflow/node_resolution.py
new file mode 100644
index 0000000000..b922c28165
--- /dev/null
+++ b/api/core/workflow/node_resolution.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from importlib import import_module
+
+from dify_graph.enums import NodeType
+from dify_graph.nodes.base.node import Node
+from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping
+
+_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",)
+_workflow_nodes_registered = False
+
+
+def ensure_workflow_nodes_registered() -> None:
+    """Import workflow-local node modules so they can register with `Node.__init_subclass__`."""
+    global _workflow_nodes_registered
+
+    if _workflow_nodes_registered:
+        return
+
+    for module_name in _WORKFLOW_NODE_MODULES:
+        import_module(module_name)
+
+    _workflow_nodes_registered = True
+
+
+def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
+    ensure_workflow_nodes_registered()
+    return get_node_type_classes_mapping()
+
+
+def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+    node_mapping = get_workflow_node_type_classes_mapping().get(node_type)
+    if not node_mapping:
+        raise ValueError(f"No class mapping found for node type: {node_type}")
+
+    latest_node_class = node_mapping.get(LATEST_VERSION)
+    matched_node_class = node_mapping.get(node_version)
+    node_class = matched_node_class or latest_node_class
+    if not node_class:
+        raise ValueError(f"No latest version class found for node type: {node_type}")
+    return node_class
diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/core/workflow/nodes/__init__.py
similarity index 100%
rename from api/tests/unit_tests/core/model_runtime/__base/__init__.py
rename to api/core/workflow/nodes/__init__.py
diff --git a/api/core/workflow/nodes/agent/__init__.py b/api/core/workflow/nodes/agent/__init__.py
new file mode 100644
index 0000000000..ba6c667194
--- /dev/null
+++ b/api/core/workflow/nodes/agent/__init__.py
@@ -0,0 +1,4 @@
+from .agent_node import AgentNode
+from .entities import AgentNodeData
+
+__all__ = ["AgentNode", "AgentNodeData"]
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
new file mode 100644
index 0000000000..c1b423d69d
--- /dev/null
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -0,0 +1,188 @@
+from __future__ import annotations
+
+from collections.abc import Generator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
+
+from dify_graph.entities.graph_config import NodeConfigDict
+from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus
+from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
+from dify_graph.nodes.base.node import Node
+from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
+
+from .entities import AgentNodeData
+from .exceptions import (
+    AgentInvocationError,
+    AgentMessageTransformError,
+)
+from .message_transformer import AgentMessageTransformer
+from .runtime_support import AgentRuntimeSupport
+from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
+
+if TYPE_CHECKING:
+    from dify_graph.entities import GraphInitParams
+    from dify_graph.runtime import GraphRuntimeState
+
+
+class AgentNode(Node[AgentNodeData]):
+    node_type = NodeType.AGENT
+
+    _strategy_resolver: AgentStrategyResolver
+    _presentation_provider: AgentStrategyPresentationProvider
+    _runtime_support: AgentRuntimeSupport
+    _message_transformer: AgentMessageTransformer
+
+    def __init__(
+        self,
+        id: str,
+        config: NodeConfigDict,
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
+        *,
+        strategy_resolver: AgentStrategyResolver,
+        presentation_provider: AgentStrategyPresentationProvider,
+        runtime_support: AgentRuntimeSupport,
+        message_transformer: AgentMessageTransformer,
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        self._strategy_resolver = strategy_resolver
+        self._presentation_provider = presentation_provider
+        self._runtime_support = runtime_support
+        self._message_transformer = message_transformer
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+
+    def populate_start_event(self, event) -> None:
+        dify_ctx = self.require_dify_context()
+        event.extras["agent_strategy"] = {
+            "name": self.node_data.agent_strategy_name,
+            "icon": self._presentation_provider.get_icon(
+                tenant_id=dify_ctx.tenant_id,
+                agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
+            ),
+        }
+
+    def _run(self) -> Generator[NodeEventBase, None, None]:
+        from core.plugin.impl.exc import PluginDaemonClientSideError
+
+        dify_ctx = self.require_dify_context()
+
+        try:
+            strategy = self._strategy_resolver.resolve(
+                tenant_id=dify_ctx.tenant_id,
+                agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
+                agent_strategy_name=self.node_data.agent_strategy_name,
+            )
+        except Exception as e:
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs={},
+                    error=f"Failed to get agent strategy: {str(e)}",
+                ),
+            )
+            return
+
+        agent_parameters = strategy.get_parameters()
+
+        parameters = self._runtime_support.build_parameters(
+            agent_parameters=agent_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+            strategy=strategy,
+            tenant_id=dify_ctx.tenant_id,
+            app_id=dify_ctx.app_id,
+            invoke_from=dify_ctx.invoke_from,
+        )
+        parameters_for_log = self._runtime_support.build_parameters(
+            agent_parameters=agent_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+            strategy=strategy,
+            tenant_id=dify_ctx.tenant_id,
+            app_id=dify_ctx.app_id,
+            invoke_from=dify_ctx.invoke_from,
+            for_log=True,
+        )
+        credentials = self._runtime_support.build_credentials(parameters=parameters)
+
+        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+
+        try:
+            message_stream = strategy.invoke(
+                params=parameters,
+                user_id=dify_ctx.user_id,
+                app_id=dify_ctx.app_id,
+                conversation_id=conversation_id.text if conversation_id else None,
+                credentials=credentials,
+            )
+        except Exception as e:
+            error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    error=str(error),
+                )
+            )
+            return
+
+        try:
+            yield from self._message_transformer.transform(
+                messages=message_stream,
+                tool_info={
+                    "icon": self._presentation_provider.get_icon(
+                        tenant_id=dify_ctx.tenant_id,
+                        agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
+                    ),
+                    "agent_strategy": self.node_data.agent_strategy_name,
+                },
+                parameters_for_log=parameters_for_log,
+                user_id=dify_ctx.user_id,
+                tenant_id=dify_ctx.tenant_id,
+                node_type=self.node_type,
+                node_id=self._node_id,
+                node_execution_id=self.id,
+            )
+        except PluginDaemonClientSideError as e:
+            transform_error = AgentMessageTransformError(
+                f"Failed to transform agent message: {str(e)}", original_error=e
+            )
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    error=str(transform_error),
+                )
+            )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        *,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: AgentNodeData,
+    ) -> Mapping[str, Sequence[str]]:
+        _ = graph_config  # Explicitly mark as unused
+        result: dict[str, Any] = {}
+        typed_node_data = node_data
+        for parameter_name in typed_node_data.agent_parameters:
+            input = typed_node_data.agent_parameters[parameter_name]
+            match input.type:
+                case "mixed" | "constant":
+                    selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
+                    for selector in selectors:
+                        result[selector.variable] = selector.value_selector
+                case "variable":
+                    result[parameter_name] = input.value
+
+        result = {node_id + "." + key: value for key, value in result.items()}
+
+        return result
diff --git a/api/dify_graph/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
similarity index 85%
rename from api/dify_graph/nodes/agent/entities.py
rename to api/core/workflow/nodes/agent/entities.py
index 9124420f01..59842862ef 100644
--- a/api/dify_graph/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -5,13 +5,15 @@ from pydantic import BaseModel
 
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.tools.entities.tool_entities import ToolSelector
-from dify_graph.nodes.base.entities import BaseNodeData
+from dify_graph.entities.base_node_data import BaseNodeData
+from dify_graph.enums import NodeType
 
 
 class AgentNodeData(BaseNodeData):
-    agent_strategy_provider_name: str  # redundancy
+    type: NodeType = NodeType.AGENT
+    agent_strategy_provider_name: str
     agent_strategy_name: str
-    agent_strategy_label: str  # redundancy
+    agent_strategy_label: str
     memory: MemoryConfig | None = None
     # The version of the tool parameter.
     # If this value is None, it indicates this is a previous version
diff --git a/api/dify_graph/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exceptions.py
similarity index 90%
rename from api/dify_graph/nodes/agent/exc.py
rename to api/core/workflow/nodes/agent/exceptions.py
index ba2c83d8a6..944f5f0b20 100644
--- a/api/dify_graph/nodes/agent/exc.py
+++ b/api/core/workflow/nodes/agent/exceptions.py
@@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError):
         self.expected_type = expected_type
         self.actual_type = actual_type
         super().__init__(message)
-
-
-class AgentMaxIterationError(AgentNodeError):
-    """Exception raised when the agent exceeds the maximum iteration limit."""
-
-    def __init__(self, max_iteration: int):
-        self.max_iteration = max_iteration
-        super().__init__(
-            f"Agent exceeded the maximum iteration limit of {max_iteration}. "
-            f"The agent was unable to complete the task within the allowed number of iterations."
- ) diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py new file mode 100644 index 0000000000..317db14d3f --- /dev/null +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) +from dify_graph.variables.segments import ArrayFileSegment +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError + + +class AgentMessageTransformer: + def transform( + self, + *, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator[NodeEventBase, None, None]: + from core.plugin.impl.plugin import PluginInstaller + + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json_list: list[dict | list] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage = LLMUsage.empty_usage() + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) 
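+ # The blob URL's basename (without extension) is the ToolFile primary
+ # key; a missing row means the transformer emitted a dangling file
+ # reference, so fail fast instead of building a broken file mapping.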
+ if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=message.message.text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == NodeType.AGENT: + if isinstance(message.message.json_object, dict): + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + else: + llm_usage = LLMUsage.empty_usage() + agent_execution_metadata = {} + if message.message.json_object: + json_list.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=stream_text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, dict) + if "file" not in message.meta: + raise AgentNodeError("File message is missing 'file' key in meta") + + if not isinstance(message.meta["file"], File): + raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + 
icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + message_id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + for log in agent_logs: + if log.message_id == agent_log.message_id: + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + json_output: list[dict[str, Any] | list[Any]] = [] + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.message_id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + if json_list: + json_output.extend(json_list) + else: + json_output.append({"data": []}) + + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk="", + is_final=True, + ) + + for var_name in variables: + yield StreamChunkEvent( + selector=[node_id, var_name], + chunk="", + is_final=True, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + **variables, + }, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/plugin_strategy_adapter.py b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py new file mode 100644 index 0000000000..1fc427ad6c --- /dev/null +++ b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from factories.agent_factory import get_plugin_agent_strategy + +from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy + + +class PluginAgentStrategyResolver(AgentStrategyResolver): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: + return get_plugin_agent_strategy( + tenant_id=tenant_id, + agent_strategy_provider_name=agent_strategy_provider_name, + agent_strategy_name=agent_strategy_name, + ) + + +class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: + from core.plugin.impl.plugin import PluginInstaller + + manager = PluginInstaller() + try: + plugins = manager.list_plugins(tenant_id) + except Exception: + return None + + try: + current_plugin = next( + plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name + ) + except StopIteration: + return None + + return current_plugin.declaration.icon diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py new file mode 100644 index 0000000000..2ff7c964b9 --- /dev/null +++ 
b/api/core/workflow/nodes/agent/runtime_support.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import json +from collections.abc import Sequence +from typing import Any, cast + +from packaging.version import Version +from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.agent.entities import AgentToolEntity +from core.agent.plugin_entities import AgentStrategyParameter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.plugin.entities.request import InvokeCredentials +from core.provider_manager import ProviderManager +from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType +from core.tools.tool_manager import ToolManager +from dify_graph.enums import SystemVariableKey +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import StringSegment +from extensions.ext_database import db +from models.model import Conversation + +from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated +from .exceptions import AgentInputTypeError, AgentVariableNotFoundError +from .strategy_protocols import ResolvedAgentStrategy + + +class AgentRuntimeSupport: + def build_parameters( + self, + *, + agent_parameters: Sequence[AgentStrategyParameter], + variable_pool: VariablePool, + node_data: AgentNodeData, + strategy: ResolvedAgentStrategy, + tenant_id: str, + app_id: str, + invoke_from: Any, + for_log: bool = False, + ) -> dict[str, Any]: + agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.agent_parameters: + parameter = agent_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + + agent_input = node_data.agent_parameters[parameter_name] + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore[arg-type] + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: + parameter_value = str(agent_input.value) + + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) + + value = parameter_value + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + value = [tool for tool in value if tool.get("enabled", False)] + value = self._filter_mcp_type_tool(strategy, value) + for tool in value: + if "schemas" in tool: + tool.pop("schemas") + parameters = tool.get("parameters", {}) + if all(isinstance(v, dict) for _, v in parameters.items()): + params = {} + for key, param in parameters.items(): + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): + value_param = param.get("value", {}) + if value_param and value_param.get("type", "") == 
"variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None + else: + params[key] = None + parameters = params + tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} + tool["parameters"] = parameters + + if not for_log: + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + tool_value = [] + for tool in value: + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + setting_params = tool.get("settings", {}) + parameters = tool.get("parameters", {}) + manual_input_params = [key for key, value in parameters.items() if value is not None] + + parameters = {**parameters, **setting_params} + entity = AgentToolEntity( + provider_id=tool.get("provider_name", ""), + provider_type=provider_type, + tool_name=tool.get("tool_name", ""), + tool_parameters=parameters, + plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), + ) + + extra = tool.get("extra", {}) + + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version is not None: + runtime_variable_pool = variable_pool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id, + app_id, + entity, + invoke_from, + runtime_variable_pool, + ) + if tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + extra.get("description", "") or tool_runtime.entity.description.llm + ) + for tool_runtime_params in tool_runtime.entity.parameters: + tool_runtime_params.form = ( + ToolParameter.ToolParameterForm.FORM + if tool_runtime_params.name in manual_input_params + else tool_runtime_params.form + ) + manual_input_value = {} + if tool_runtime.entity.parameters: + manual_input_value = { + key: value for key, value in parameters.items() if key in manual_input_params + } + runtime_parameters = { + **tool_runtime.runtime.runtime_parameters, + **manual_input_value, + } + tool_value.append( + { + **tool_runtime.entity.model_dump(mode="json"), + "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), + "provider_type": provider_type.value, + } + ) + value = tool_value + if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: + value = cast(dict[str, Any], value) + model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + history_prompt_messages = [] + if node_data.memory: + memory = self.fetch_memory( + variable_pool=variable_pool, + app_id=app_id, + model_instance=model_instance, + ) + if memory: + prompt_messages = memory.get_history_prompt_messages( + message_limit=node_data.memory.window.size or None + ) + history_prompt_messages = [ + prompt_message.model_dump(mode="json") for prompt_message in prompt_messages + ] + value["history_prompt_messages"] = history_prompt_messages + if model_schema: + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None + result[parameter_name] = value + + return result + + def build_credentials(self, *, parameters: dict[str, Any]) -> 
InvokeCredentials: + credentials = InvokeCredentials() + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if not tool.get("credential_id"): + continue + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + except ValidationError: + continue + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + return credentials + + def fetch_memory( + self, + *, + variable_pool: VariablePool, + app_id: str, + model_instance: ModelInstance, + ) -> TokenBufferMemory | None: + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None + + return TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=tenant_id, + provider=value.get("provider", ""), + model_type=ModelType.LLM, + ) + model_name = value.get("model", "") + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, + model=model_name, + ) + provider_name = provider_model_bundle.configuration.provider.provider + model_type_instance = provider_model_bundle.model_type_instance + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + provider=provider_name, + model_type=ModelType(value.get("model_type", "")), + model=model_name, + ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_instance, model_schema + + @staticmethod + def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features[:]: + try: + AgentOldVersionModelFeatures(feature.value) + except ValueError: + model_schema.features.remove(feature) + return model_schema + + @staticmethod + def _filter_mcp_type_tool( + strategy: ResolvedAgentStrategy, + tools: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + meta_version = strategy.meta_version + if meta_version and Version(meta_version) > Version("0.0.1"): + return tools + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] diff --git a/api/core/workflow/nodes/agent/strategy_protocols.py b/api/core/workflow/nodes/agent/strategy_protocols.py new file mode 100644 index 0000000000..643d916d15 --- /dev/null +++ b/api/core/workflow/nodes/agent/strategy_protocols.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Generator, Sequence +from typing import Any, Protocol + +from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class ResolvedAgentStrategy(Protocol): + meta_version: str | None + + def get_parameters(self) -> Sequence[AgentStrategyParameter]: ... 
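+ # Structural match for the plugin agent strategy surface: `invoke` streams
+ # ToolInvokeMessage chunks for AgentMessageTransformer to consume, so any
+ # resolver returning an object of this shape satisfies the protocol.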
+ + def invoke( + self, + *, + params: dict[str, Any], + user_id: str, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: ... + + +class AgentStrategyResolver(Protocol): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: ... + + +class AgentStrategyPresentationProvider(Protocol): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ... diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c259e7ac08..01b309bf54 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -9,9 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_resolution import resolve_workflow_node_class from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.models import File from dify_graph.graph import Graph @@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import NodeType from dify_graph.nodes.base.node import Node -from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool @@ -212,7 +212,7 @@ class WorkflowEntry: node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data["type"]) + node_type = node_config_data.type # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -234,8 +234,7 @@ class WorkflowEntry: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - typed_node_config = cast(dict[str, object], node_config) - node = cast(Any, node_factory).create_node(typed_node_config) + node = node_factory.create_node(node_config) node_cls = type(node) try: @@ -344,7 +343,7 @@ class WorkflowEntry: if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] + node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1") if not node_cls: raise ValueError(f"Node class not found for node type {node_type}") @@ -371,10 +370,7 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config: NodeConfigDict = { - "id": node_id, - "data": cast(NodeConfigData, node_data), - } + node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, 
graph_runtime_state=graph_runtime_state, diff --git a/api/dify_graph/entities/__init__.py b/api/dify_graph/entities/__init__.py index e73c38c1d3..ef7789c49c 100644 --- a/api/dify_graph/entities/__init__.py +++ b/api/dify_graph/entities/__init__.py @@ -1,11 +1,9 @@ -from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution from .workflow_start_reason import WorkflowStartReason __all__ = [ - "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", diff --git a/api/dify_graph/entities/agent.py b/api/dify_graph/entities/agent.py deleted file mode 100644 index 2b4d6db76f..0000000000 --- a/api/dify_graph/entities/agent.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class AgentNodeStrategyInit(BaseModel): - """Agent node strategy initialization data.""" - - name: str - icon: str | None = None diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py new file mode 100644 index 0000000000..58869a94c2 --- /dev/null +++ b/api/dify_graph/entities/base_node_data.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import json +from abc import ABC +from builtins import type as type_ +from enum import StrEnum +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from dify_graph.entities.exc import DefaultValueTypeError +from dify_graph.enums import ErrorStrategy, NodeType + +# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. +_NumberType = Union[int, float] + + +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds + retry_enabled: bool = False # whether retry is enabled + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +class DefaultValue(BaseModel): + value: Any = None + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str): + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> DefaultValue: + # Type validation configuration + type_validators: dict[DefaultValueType, dict[str, Any]] = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": _NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": _NumberType, + "converter": self._parse_json, + }, + 
DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator: dict[str, Any] = type_validators.get(self.type, {}) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self + + +class BaseNodeData(ABC, BaseModel): + # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where + # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # At that boundary, node-specific fields are still "extra" relative to this shared DTO, + # and persisted templates/workflows also carry undeclared compatibility keys such as + # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive + # here until graph parsing becomes discriminated by node type or those legacy payloads + # are normalized. + model_config = ConfigDict(extra="allow") + + type: NodeType + title: str = "" + desc: str | None = None + version: str = "1" + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None + retry_config: RetryConfig = Field(default_factory=RetryConfig) + + @property + def default_value_dict(self) -> dict[str, Any]: + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} + + def __getitem__(self, key: str) -> Any: + """ + Dict-style access without calling model_dump() on every lookup. + Prefer using model fields and Pydantic's extra storage. + """ + # First, check declared model fields + if key in self.__class__.model_fields: + return getattr(self, key) + + # Then, check undeclared compatibility fields stored in Pydantic's extra dict. + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras[key] + + raise KeyError(key) + + def get(self, key: str, default: Any = None) -> Any: + """ + Dict-style .get() without calling model_dump() on every lookup. 
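+
+ Lookup order: declared model fields, then the compatibility keys kept in
+ Pydantic's extra storage, then `default`. For example, assuming a payload
+ validated through `NodeConfigDictAdapter`, `data.get("title", "")` reads
+ the declared field while `data.get("paramSchemas")` falls back to extras.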
+ """ + if key in self.__class__.model_fields: + return getattr(self, key) + + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras.get(key, default) + + return default diff --git a/api/dify_graph/nodes/base/exc.py b/api/dify_graph/entities/exc.py similarity index 100% rename from api/dify_graph/nodes/base/exc.py rename to api/dify_graph/entities/exc.py diff --git a/api/dify_graph/entities/graph_config.py b/api/dify_graph/entities/graph_config.py index 209dcfe6bc..36f7b94e82 100644 --- a/api/dify_graph/entities/graph_config.py +++ b/api/dify_graph/entities/graph_config.py @@ -4,21 +4,20 @@ import sys from pydantic import TypeAdapter, with_config +from dify_graph.entities.base_node_data import BaseNodeData + if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict -@with_config(extra="allow") -class NodeConfigData(TypedDict): - type: str - - @with_config(extra="allow") class NodeConfigDict(TypedDict): id: str - data: NodeConfigData + # This is the permissive raw graph boundary. Node factories re-validate `data` + # with the concrete `NodeData` subtype after resolving the node implementation. + data: BaseNodeData NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py index 3fe94eb3fd..3eb6bfc359 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -8,7 +8,7 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState from dify_graph.nodes.base.node import Node from libs.typing import is_str @@ -34,7 +34,8 @@ class NodeFactory(Protocol): :param node_config: node configuration dictionary containing type and other data :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid + :raises ValueError: if node type is unknown or no implementation exists for the resolved version + :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation """ ... @@ -115,10 +116,7 @@ class Graph: start_node_id = None for nid in root_candidates: node_data = node_configs_map[nid]["data"] - node_type = node_data["type"] - if not isinstance(node_type, str): - continue - if NodeType(node_type).is_start_node: + if node_data.type.is_start_node: start_node_id = nid break @@ -203,6 +201,23 @@ class Graph: return GraphBuilder(graph_cls=cls) + @staticmethod + def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: + """ + Remove editor-only nodes before `NodeConfigDict` validation. + + Persisted note widgets use a top-level `type == "custom-note"` but leave + `data.type` empty because they are never executable graph nodes. Filter + them while configs are still raw dicts so Pydantic does not validate + their placeholder payloads against `BaseNodeData.type: NodeType`. 
+ """ + filtered_node_configs: list[dict[str, object]] = [] + for node_config in node_configs: + if node_config.get("type", "") == "custom-note": + continue + filtered_node_configs.append(dict(node_config)) + return filtered_node_configs + @classmethod def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: """ @@ -302,13 +317,13 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) + node_configs = cast(list[dict[str, object]], node_configs) + node_configs = cls._filter_canvas_only_nodes(node_configs) node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") - node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"] - # Parse node configurations node_configs_map = cls._parse_node_configs(node_configs) diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py index 21ddf80b64..8552254627 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -4,7 +4,6 @@ from datetime import datetime from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities import AgentNodeStrategyInit from dify_graph.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -13,8 +12,8 @@ from .base import GraphNodeEventBase class NodeRunStartedEvent(GraphNodeEventBase): node_title: str predecessor_node_id: str | None = None - agent_strategy: AgentNodeStrategyInit | None = None start_at: datetime = Field(..., description="node start time") + extras: dict[str, object] = Field(default_factory=dict) # FIXME(-LAN-): only for ToolNode provider_type: str = "" diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py index 9e46d72893..402bfdc606 100644 --- a/api/dify_graph/model_runtime/entities/message_entities.py +++ b/api/dify_graph/model_runtime/entities/message_entities.py @@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage): :return: True if prompt message is empty, False otherwise """ - if not super().is_empty() and not self.tool_call_id: - return False - - return True + return super().is_empty() and not self.tool_call_id diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py index 80cf01fb6c..1a57078b98 100644 --- a/api/dify_graph/model_runtime/errors/invoke.py +++ b/api/dify_graph/model_runtime/errors/invoke.py @@ -4,7 +4,8 @@ class InvokeError(ValueError): description: str | None = None def __init__(self, description: str | None = None): - self.description = description + if description is not None: + self.description = description def __str__(self): return self.description or self.__class__.__name__ diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py index e168fc11d1..de0677a348 100644 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -282,7 +282,8 @@ class ModelProviderFactory: all_model_type_models.append(model_schema) simple_provider_schema = provider_schema.to_simple_provider() - simple_provider_schema.models.extend(all_model_type_models) + if model_type: + simple_provider_schema.models = all_model_type_models 
providers.append(simple_provider_schema) diff --git a/api/dify_graph/nodes/agent/__init__.py b/api/dify_graph/nodes/agent/__init__.py deleted file mode 100644 index 95e7cf895b..0000000000 --- a/api/dify_graph/nodes/agent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .agent_node import AgentNode - -__all__ = ["AgentNode"] diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py deleted file mode 100644 index d770f7afd1..0000000000 --- a/api/dify_graph/nodes/agent/agent_node.py +++ /dev/null @@ -1,762 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from packaging.version import Version -from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.agent.entities import AgentToolEntity -from core.agent.plugin_entities import AgentStrategyParameter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager -from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ( - ToolIdentity, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, -) -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import ( - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( - AgentLogEvent, - NodeEventBase, - NodeRunResult, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, StringSegment -from extensions.ext_database import db -from factories import file_factory -from factories.agent_factory import get_plugin_agent_strategy -from models import ToolFile -from models.model import Conversation -from services.tools.builtin_tools_manage_service import BuiltinToolManageService - -from .exc import ( - AgentInputTypeError, - AgentInvocationError, - AgentMessageTransformError, - AgentNodeError, - AgentVariableNotFoundError, - AgentVariableTypeError, - ToolFileNotFoundError, -) - -if TYPE_CHECKING: - from core.agent.strategy.plugin import PluginAgentStrategy - from core.plugin.entities.request import InvokeCredentials - - -class AgentNode(Node[AgentNodeData]): - """ - Agent Node - """ - - node_type = NodeType.AGENT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[NodeEventBase, None, None]: - from core.plugin.impl.exc import PluginDaemonClientSideError - - dify_ctx = self.require_dify_context() - - try: - strategy = get_plugin_agent_strategy( - tenant_id=dify_ctx.tenant_id, - agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, - agent_strategy_name=self.node_data.agent_strategy_name, - ) - except Exception as e: - yield StreamCompletedEvent( - 
node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - error=f"Failed to get agent strategy: {str(e)}", - ), - ) - return - - agent_parameters = strategy.get_parameters() - - # get parameters - parameters = self._generate_agent_parameters( - agent_parameters=agent_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - strategy=strategy, - ) - parameters_for_log = self._generate_agent_parameters( - agent_parameters=agent_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - strategy=strategy, - ) - credentials = self._generate_credentials(parameters=parameters) - - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - - try: - message_stream = strategy.invoke( - params=parameters, - user_id=dify_ctx.user_id, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, - credentials=credentials, - ) - except Exception as e: - error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - error=str(error), - ) - ) - return - - try: - yield from self._transform_message( - messages=message_stream, - tool_info={ - "icon": self.agent_strategy_icon, - "agent_strategy": self.node_data.agent_strategy_name, - }, - parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - node_type=self.node_type, - node_id=self._node_id, - node_execution_id=self.id, - ) - except PluginDaemonClientSideError as e: - transform_error = AgentMessageTransformError( - f"Failed to transform agent message: {str(e)}", original_error=e - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - error=str(transform_error), - ) - ) - - def _generate_agent_parameters( - self, - *, - agent_parameters: Sequence[AgentStrategyParameter], - variable_pool: VariablePool, - node_data: AgentNodeData, - for_log: bool = False, - strategy: PluginAgentStrategy, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - agent_parameters (Sequence[AgentParameter]): The list of agent parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (AgentNodeData): The data associated with the agent node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. 
- - """ - agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - parameter = agent_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - agent_input = node_data.agent_parameters[parameter_name] - match agent_input.type: - case "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - case "mixed" | "constant": - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: - parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - case _: - raise AgentInputTypeError(agent_input.type) - value = parameter_value - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - value = [tool for tool in value if tool.get("enabled", False)] - value = self._filter_mcp_type_tool(strategy, value) - for tool in value: - if "schemas" in tool: - tool.pop("schemas") - parameters = tool.get("parameters", {}) - if all(isinstance(v, dict) for _, v in parameters.items()): - params = {} - for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN) in ( - ParamsAutoGenerated.CLOSE, - 0, - ): - value_param = param.get("value", {}) - if value_param and value_param.get("type", "") == "variable": - variable_selector = value_param.get("value") - if not variable_selector: - raise ValueError("Variable selector is missing for a variable-type parameter.") - - variable = variable_pool.get(variable_selector) - if variable is None: - raise AgentVariableNotFoundError(str(variable_selector)) - - params[key] = variable.value - else: - params[key] = value_param.get("value", "") if value_param is not None else None - else: - params[key] = None - parameters = params - tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} - tool["parameters"] = parameters - - if not for_log: - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - tool_value = [] - for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) - setting_params = tool.get("settings", {}) - parameters = tool.get("parameters", {}) - manual_input_params = [key for key, value in parameters.items() if value is not None] - - parameters = {**parameters, **setting_params} - entity = AgentToolEntity( - provider_id=tool.get("provider_name", ""), - provider_type=provider_type, - tool_name=tool.get("tool_name", ""), - tool_parameters=parameters, - plugin_unique_identifier=tool.get("plugin_unique_identifier", None), - credential_id=tool.get("credential_id", None), - ) - - extra = tool.get("extra", {}) - - # This is an issue that caused problems before. 
- # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - runtime_variable_pool: VariablePool | None = None - if node_data.version != "1" or node_data.tool_node_version is not None: - runtime_variable_pool = variable_pool - dify_ctx = self.require_dify_context() - tool_runtime = ToolManager.get_agent_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - entity, - dify_ctx.invoke_from, - runtime_variable_pool, - ) - if tool_runtime.entity.description: - tool_runtime.entity.description.llm = ( - extra.get("description", "") or tool_runtime.entity.description.llm - ) - for tool_runtime_params in tool_runtime.entity.parameters: - tool_runtime_params.form = ( - ToolParameter.ToolParameterForm.FORM - if tool_runtime_params.name in manual_input_params - else tool_runtime_params.form - ) - manual_input_value = {} - if tool_runtime.entity.parameters: - manual_input_value = { - key: value for key, value in parameters.items() if key in manual_input_params - } - runtime_parameters = { - **tool_runtime.runtime.runtime_parameters, - **manual_input_value, - } - tool_value.append( - { - **tool_runtime.entity.model_dump(mode="json"), - "runtime_parameters": runtime_parameters, - "credential_id": tool.get("credential_id", None), - "provider_type": provider_type.value, - } - ) - value = tool_value - if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: - value = cast(dict[str, Any], value) - model_instance, model_schema = self._fetch_model(value) - # memory config - history_prompt_messages = [] - if node_data.memory: - memory = self._fetch_memory(model_instance) - if memory: - prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size or None - ) - history_prompt_messages = [ - prompt_message.model_dump(mode="json") for prompt_message in prompt_messages - ] - value["history_prompt_messages"] = history_prompt_messages - if model_schema: - # remove structured output feature to support old version agent plugin - model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) - value["entity"] = model_schema.model_dump(mode="json") - else: - value["entity"] = None - result[parameter_name] = value - - return result - - def _generate_credentials( - self, - parameters: dict[str, Any], - ) -> InvokeCredentials: - """ - Generate credentials based on the given agent parameters. 
- """ - from core.plugin.entities.request import InvokeCredentials - - credentials = InvokeCredentials() - - # generate credentials for tools selector - credentials.tool_credentials = {} - for tool in parameters.get("tools", []): - if tool.get("credential_id"): - try: - identity = ToolIdentity.model_validate(tool.get("identity", {})) - credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) - except ValidationError: - continue - return credentials - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AgentNodeData.model_validate(node_data) - - result: dict[str, Any] = {} - for parameter_name in typed_node_data.agent_parameters: - input = typed_node_data.agent_parameters[parameter_name] - match input.type: - case "mixed" | "constant": - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - result[parameter_name] = input.value - - result = {node_id + "." + key: value for key, value in result.items()} - - return result - - @property - def agent_strategy_icon(self) -> str | None: - """ - Get agent strategy icon - :return: - """ - from core.plugin.impl.plugin import PluginInstaller - - manager = PluginInstaller() - dify_ctx = self.require_dify_context() - plugins = manager.list_plugins(dify_ctx.tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name - ) - icon = current_plugin.declaration.icon - except StopIteration: - icon = None - return icon - - def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: - # get conversation id - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - if not isinstance(conversation_id_variable, StringSegment): - return None - conversation_id = conversation_id_variable.value - - dify_ctx = self.require_dify_context() - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where( - Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id - ) - conversation = session.scalar(stmt) - - if not conversation: - return None - - memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) - - return memory - - def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - dify_ctx = self.require_dify_context() - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM - ) - model_name = value.get("model", "") - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, model=model_name - ) - provider_name = provider_model_bundle.configuration.provider.provider - model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( - tenant_id=dify_ctx.tenant_id, - provider=provider_name, - model_type=ModelType(value.get("model_type", "")), - model=model_name, - ) - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_instance, 
model_schema - - def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: - if model_schema.features: - for feature in model_schema.features[:]: # Create a copy to safely modify during iteration - try: - AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value - except ValueError: - model_schema.features.remove(feature) - return model_schema - - def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Filter MCP type tool - :param strategy: plugin agent strategy - :param tool: tool - :return: filtered tool dict - """ - meta_version = strategy.meta_version - if meta_version and Version(meta_version) > Version("0.0.1"): - return tools - else: - return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] - - def _transform_message( - self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, - node_type: NodeType, - node_id: str, - node_execution_id: str, - ) -> Generator[NodeEventBase, None, None]: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json_list: list[dict | list] = [] - - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage = LLMUsage.empty_usage() - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) 
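# Editor's sketch of the meta_version gate in the removed _filter_mcp_type_tool
# above: packaging's Version gives proper semantic ordering ("0.0.10" > "0.0.2",
# unlike plain string comparison). The helper name is illustrative.
from packaging.version import Version

def keeps_mcp_tools(meta_version: str | None) -> bool:
    # Strategies newer than 0.0.1 are assumed to understand MCP-type tools.
    return bool(meta_version) and Version(meta_version) > Version("0.0.1")

assert keeps_mcp_tools("0.0.10")
assert not keeps_mcp_tools("0.0.1")
assert not keeps_mcp_tools(None)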
- text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if node_type == NodeType.AGENT: - if isinstance(message.message.json_object, dict): - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } - else: - msg_metadata = {} - llm_usage = LLMUsage.empty_usage() - agent_execution_metadata = {} - if message.message.json_object: - json_list.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise AgentVariableTypeError( - "When 'stream' is True, 'variable_value' must be a string.", - variable_name=variable_name, - expected_type="str", - actual_type=type(variable_value).__name__, - ) - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise AgentNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - message_id=message.message.id, - node_execution_id=node_execution_id, - 
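# Editor's sketch of the streamed-variable contract handled above: chunks for
# one variable concatenate in arrival order, and a final empty chunk with
# is_final=True closes each stream. Minimal stand-in:
variables: dict[str, str] = {}

def on_variable_chunk(name: str, chunk: str) -> None:
    variables[name] = variables.get(name, "") + chunk

on_variable_chunk("thought", "Searching")
on_variable_chunk("thought", " docs...")
assert variables["thought"] == "Searching docs..."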
parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.message_id == agent_log.message_id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 1: append each agent log as its own dict. - if agent_logs: - for log in agent_logs: - json_output.append( - { - "id": log.message_id, - "parent_id": log.parent_id, - "error": log.error, - "status": log.status, - "data": log.data, - "label": log.label, - "metadata": log.metadata, - "node_id": log.node_id, - } - ) - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json_list: - json_output.extend(json_list) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[node_id, var_name], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "text": text, - "usage": jsonable_encoder(llm_usage), - "files": ArrayFileSegment(value=files), - "json": json_output, - **variables, - }, - metadata={ - **agent_execution_metadata, - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, - }, - inputs=parameters_for_log, - llm_usage=llm_usage, - ) - ) diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py index d07b9c8062..c829b892cc 100644 --- a/api/dify_graph/nodes/answer/answer_node.py +++ b/api/dify_graph/nodes/answer/answer_node.py @@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: AnswerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AnswerNodeData.model_validate(node_data) - - variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) + _ = graph_config # Explicitly mark as unused + variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/dify_graph/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py index 06927cd71e..3cc1d6572e 100644 --- a/api/dify_graph/nodes/answer/entities.py +++ b/api/dify_graph/nodes/answer/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class AnswerNodeData(BaseNodeData): @@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData): Answer Node Data. 
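# Editor's sketch (call shape assumed from this diff, not verified against the
# repo) of the typed contract AnswerNode now uses: validate the payload once
# into AnswerNodeData, then pass the instance instead of a raw mapping.
from dify_graph.nodes.answer.answer_node import AnswerNode
from dify_graph.nodes.answer.entities import AnswerNodeData

data = AnswerNodeData.model_validate({"title": "answer", "answer": "{{#llm.text#}}"})
mapping = AnswerNode._extract_variable_selector_to_variable_mapping(
    graph_config={}, node_id="answer_1", node_data=data
)
# Expected to surface the referenced selector, roughly
# {"answer_1.#llm.text#": ["llm", "text"]}.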
""" + type: NodeType = NodeType.ANSWER answer: str = Field(..., description="answer template string") diff --git a/api/dify_graph/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py index f83df0e323..036e25895d 100644 --- a/api/dify_graph/nodes/base/__init__.py +++ b/api/dify_graph/nodes/base/__init__.py @@ -1,4 +1,4 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ @@ -6,6 +6,5 @@ __all__ = [ "BaseIterationState", "BaseLoopNodeData", "BaseLoopState", - "BaseNodeData", "LLMUsageTrackingMixin", ] diff --git a/api/dify_graph/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py index 956fa59e78..4f8b2682e1 100644 --- a/api/dify_graph/nodes/base/entities.py +++ b/api/dify_graph/nodes/base/entities.py @@ -1,31 +1,12 @@ from __future__ import annotations -import json -from abc import ABC -from builtins import type as type_ from collections.abc import Sequence from enum import StrEnum -from typing import Any, Union +from typing import Any -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, field_validator -from dify_graph.enums import ErrorStrategy - -from .exc import DefaultValueTypeError - -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 +from dify_graph.entities.base_node_data import BaseNodeData class VariableSelector(BaseModel): @@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel): return v -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": 
self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - title: str - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = RetryConfig() - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - class BaseIterationNodeData(BaseNodeData): start_node_id: str | None = None diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 1f99a0a6e2..2044b09333 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -11,7 +11,9 @@ from types import MappingProxyType from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams +from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import ( ErrorStrategy, @@ -62,8 +64,6 @@ from dify_graph.node_events import ( from dify_graph.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now -from .entities import BaseNodeData, RetryConfig - NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]): Later, in __init__: :: - config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate() - │ - ▼ - CodeNodeData instance - (stored in self._node_data) + config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) + │ + ▼ + CodeNodeData instance + (stored in self._node_data) Example: class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted @@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, ) -> None: @@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]): self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required.") + node_id = config["id"] self._node_id = node_id self._node_execution_id: str = "" self._start_at = naive_utc_now() - raw_node_data = config.get("data") or 
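# Editor's sketch of the string coercion the DefaultValue model performs (the
# model is deleted here and, judging by the new imports, lives on alongside
# BaseNodeData in dify_graph.entities.base_node_data): string input is
# JSON-parsed, then element types are checked. Stand-in for array[number]:
import json

def coerce_array_number(value: object) -> list:
    if isinstance(value, str):
        value = json.loads(value)
    if not (isinstance(value, list) and all(isinstance(x, (int, float)) for x in value)):
        raise ValueError(f"Cannot convert to array[number]: {value!r}")
    return value

assert coerce_array_number("[1, 2.5]") == [1, 2.5]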
{} - if not isinstance(raw_node_data, Mapping): - raise ValueError("Node config data must be a mapping.") - - self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data) + self._node_data = self.validate_node_data(config["data"]) self.post_init() + @classmethod + def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model.""" + return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def post_init(self) -> None: """Optional hook for subclasses requiring extra initialization.""" return @@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]): return None return str(execution_id) - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: - return cast(NodeDataT, self._node_data_type.model_validate(data)) - @abstractmethod def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ @@ -353,6 +349,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def populate_start_event(self, event: NodeRunStartedEvent) -> None: + """Allow subclasses to enrich the started event without cross-node imports in the base class.""" + _ = event + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -366,41 +366,10 @@ class Node(Generic[NodeDataT]): in_iteration_id=None, start_at=self._start_at, ) - - # === FIXME(-LAN-): Needs to refactor. - from dify_graph.nodes.tool.tool_node import ToolNode - - if isinstance(self, ToolNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from dify_graph.nodes.datasource.datasource_node import DatasourceNode - - if isinstance(self, DatasourceNode): - plugin_id = getattr(self.node_data, "plugin_id", "") - provider_name = getattr(self.node_data, "provider_name", "") - - start_event.provider_id = f"{plugin_id}/{provider_name}" - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode - - if isinstance(self, TriggerEventNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from typing import cast - - from dify_graph.nodes.agent.agent_node import AgentNode - from dify_graph.nodes.agent.entities import AgentNodeData - - if isinstance(self, AgentNode): - start_event.agent_strategy = AgentNodeStrategyInit( - name=cast(AgentNodeData, self.node_data).agent_strategy_name, - icon=self.agent_strategy_icon, - ) - - # === + try: + self.populate_start_event(start_event) + except Exception: + logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) yield start_event try: @@ -442,7 +411,7 @@ class Node(Generic[NodeDataT]): cls, *, graph_config: Mapping[str, Any], - config: Mapping[str, Any], + config: NodeConfigDict, ) -> Mapping[str, Sequence[str]]: """Extracts references variable selectors from node configuration. 
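# Editor's sketch of why validate_node_data passes from_attributes=True: in
# pydantic v2 this lets a model validate another model instance via its
# attributes, so a shared BaseNodeData payload narrows to the subclass model
# without first being dumped to a dict. Self-contained stand-in:
from pydantic import BaseModel

class SharedData(BaseModel):
    title: str

class CodeLikeData(SharedData):
    code: str = ""

shared = SharedData(title="node")
narrowed = CodeLikeData.model_validate(shared, from_attributes=True)
assert narrowed.title == "node" and narrowed.code == ""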
@@ -480,13 +449,12 @@ class Node(Generic[NodeDataT]): :param config: node config :return: """ - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - - # Pass raw dict data instead of creating NodeData instance + node_id = config["id"] + node_data = cls.validate_node_data(config["data"]) data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) + graph_config=graph_config, + node_id=node_id, + node_data=node_data, ) return data @@ -496,7 +464,7 @@ class Node(Generic[NodeDataT]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: NodeDataT, ) -> Mapping[str, Sequence[str]]: return {} @@ -520,10 +488,8 @@ class Node(Generic[NodeDataT]): @abstractmethod def version(cls) -> str: """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. - # - # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` - # in `api/dify_graph/nodes/__init__.py`. + # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so + # `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod @@ -531,7 +497,9 @@ class Node(Generic[NodeDataT]): """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under dify_graph.nodes so subclasses register themselves on import. - Then we return a readonly view of the registry to avoid accidental mutation. + Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import + those modules before invoking this method so they can register through `__init_subclass__`. + We then return a readonly view of the registry to avoid accidental mutation. """ # Import all node modules to ensure they are loaded (thus registered) import dify_graph.nodes as _nodes_pkg diff --git a/api/dify_graph/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py index 83e72deea9..ac8d6463b9 100644 --- a/api/dify_graph/nodes/code/code_node.py +++ b/api/dify_graph/nodes/code/code_node.py @@ -3,6 +3,7 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: CodeNodeData, ) -> Mapping[str, Sequence[str]]: _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = CodeNodeData.model_validate(node_data) - return { node_id + "." 
+ variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } @property diff --git a/api/dify_graph/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py index 9e161c29d0..25e46226e1 100644 --- a/api/dify_graph/nodes/code/entities.py +++ b/api/dify_graph/nodes/code/entities.py @@ -3,7 +3,8 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import VariableSelector from dify_graph.variables.types import SegmentType @@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = NodeType.CODE + class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] children: dict[str, "CodeNodeData.Output"] | None = None diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py index b97394744e..62dcb2924f 100644 --- a/api/dify_graph/nodes/datasource/datasource_node.py +++ b/api/dify_graph/nodes/datasource/datasource_node.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey from dify_graph.node_events import NodeRunResult, StreamCompletedEvent @@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", datasource_manager: DatasourceManagerProtocol, @@ -47,6 +48,10 @@ class DatasourceNode(Node[DatasourceNodeData]): ) self.datasource_manager = datasource_manager + def populate_start_event(self, event) -> None: + event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}" + event.provider_type = self.node_data.provider_type + def _run(self) -> Generator: """ Run the datasource node @@ -181,7 +186,7 @@ class DatasourceNode(Node[DatasourceNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DatasourceNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -190,11 +195,10 @@ class DatasourceNode(Node[DatasourceNodeData]): :param node_data: node data :return: """ - typed_node_data = DatasourceNodeData.model_validate(node_data) result = {} - if typed_node_data.datasource_parameters: - for parameter_name in typed_node_data.datasource_parameters: - input = typed_node_data.datasource_parameters[parameter_name] + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] match input.type: case "mixed": assert isinstance(input.value, str) diff --git a/api/dify_graph/nodes/datasource/entities.py b/api/dify_graph/nodes/datasource/entities.py index ba49e65f31..38275ac158 100644 --- a/api/dify_graph/nodes/datasource/entities.py +++ b/api/dify_graph/nodes/datasource/entities.py @@ -3,7 +3,8 @@ from typing import Any, Literal, Union from pydantic import 
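# Editor's self-contained mini-version of the hook pattern above (names are
# stand-ins, not the repo's): the base class exposes populate_start_event and
# each subclass enriches its own start event, replacing the old isinstance
# chain in Node.run().
class StartEventStub:
    provider_id: str = ""
    provider_type: str = ""

class BaseNodeStub:
    def populate_start_event(self, event: StartEventStub) -> None:
        _ = event  # default: nothing to add

class DatasourceLikeNode(BaseNodeStub):
    plugin_id, provider_name, provider_type = "pid", "prov", "datasource"

    def populate_start_event(self, event: StartEventStub) -> None:
        event.provider_id = f"{self.plugin_id}/{self.provider_name}"
        event.provider_type = self.provider_type

event = StartEventStub()
DatasourceLikeNode().populate_start_event(event)
assert event.provider_id == "pid/prov"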
BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class DatasourceEntity(BaseModel): @@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): + type: NodeType = NodeType.DATASOURCE + class DatasourceInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py index f4949d0df8..9f42d2e605 100644 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ b/api/dify_graph/nodes/document_extractor/entities.py @@ -1,10 +1,12 @@ from collections.abc import Sequence from dataclasses import dataclass -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class DocumentExtractorNodeData(BaseNodeData): + type: NodeType = NodeType.DOCUMENT_EXTRACTOR variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py index c26b18aac9..fe51b1963e 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, file_manager from dify_graph.node_events import NodeRunResult @@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DocumentExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = DocumentExtractorNodeData.model_validate(node_data) - - return {node_id + ".files": typed_node_data.variable_selector} + _ = graph_config # Explicitly mark as unused + return {node_id + ".files": node_data.variable_selector} def _extract_text_by_mime_type( diff --git a/api/dify_graph/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py index a410087214..69cd1dd8f5 100644 --- a/api/dify_graph/nodes/end/entities.py +++ b/api/dify_graph/nodes/end/entities.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field -from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import OutputVariableEntity class EndNodeData(BaseNodeData): @@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData): END Node Data. 
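# Editor's stand-in for the recurring entity change in this diff: each NodeData
# model now declares a NodeType field with a class-specific default, so a bare
# payload validates with the right type tag without the caller supplying it.
from enum import StrEnum
from pydantic import BaseModel

class NodeTypeStub(StrEnum):  # stand-in for dify_graph.enums.NodeType
    END = "end"

class EndLikeData(BaseModel):
    type: NodeTypeStub = NodeTypeStub.END
    title: str = ""

assert EndLikeData().type is NodeTypeStub.END
assert EndLikeData.model_validate({"title": "done"}).type is NodeTypeStub.END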
""" + type: NodeType = NodeType.END outputs: list[OutputVariableEntity] diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py index a5564689f8..46e08ea1a0 100644 --- a/api/dify_graph/nodes/http_request/entities.py +++ b/api/dify_graph/nodes/http_request/entities.py @@ -8,7 +8,8 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" @@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = NodeType.HTTP_REQUEST method: Literal[ "get", "post", diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 2e48d5502a..3895ae92c0 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -3,6 +3,7 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult @@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HttpRequestNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = HttpRequestNodeData.model_validate(node_data) - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) - if typed_node_data.body: - body_type = typed_node_data.body.type - data = typed_node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data match body_type: case "none": pass diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py index 5616949dcc..642c2143e5 100644 --- a/api/dify_graph/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self from pydantic import BaseModel, Field, field_validator, model_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH @@ -71,8 +72,8 @@ class EmailDeliveryConfig(BaseModel): body: str 
debug_mode: bool = False - def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig": - if not user_id: + def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": + if user_id is None: debug_recipients = EmailRecipients(whole_workspace=False, items=[]) return self.model_copy(update={"recipients": debug_recipients}) debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) @@ -140,7 +141,7 @@ def apply_debug_email_recipient( method: DeliveryChannelConfig, *, enabled: bool, - user_id: str, + user_id: str | None, ) -> DeliveryChannelConfig: if not enabled: return method @@ -148,7 +149,7 @@ def apply_debug_email_recipient( return method if not method.config.debug_mode: return method - debug_config = method.config.with_debug_recipient(user_id or "") + debug_config = method.config.with_debug_recipient(user_id) return method.model_copy(update={"config": debug_config}) @@ -214,6 +215,7 @@ class UserAction(BaseModel): class HumanInputNodeData(BaseNodeData): """Human Input node data.""" + type: NodeType = NodeType.HUMAN_INPUT delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 03c2d17b1d..3a167d122b 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -3,6 +3,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import ( @@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", form_repository: HumanInputFormRepository, @@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HumanInputNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selectors referenced in form content and input default values. @@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]): 1. Variables referenced in form_content ({{#node_name.var_name#}}) 2. Variables referenced in input default values """ - validated_node_data = HumanInputNodeData.model_validate(node_data) - return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) + return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py index 4733944039..c9bb1cdc7f 100644 --- a/api/dify_graph/nodes/if_else/entities.py +++ b/api/dify_graph/nodes/if_else/entities.py @@ -2,7 +2,8 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.utils.condition.entities import Condition @@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData): If Else Node Data. 
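# Editor's stand-in for the with_debug_recipient change above: user_id is now
# Optional, and only an explicit None clears the recipient list, whereas the
# old truthiness check also swallowed empty strings.
from pydantic import BaseModel

class EmailConfigStub(BaseModel):
    recipients: list[str] = []

    def with_debug_recipient(self, user_id: str | None) -> "EmailConfigStub":
        items = [] if user_id is None else [user_id]
        return self.model_copy(update={"recipients": items})

assert EmailConfigStub(recipients=["ops"]).with_debug_recipient(None).recipients == []
assert EmailConfigStub().with_debug_recipient("u-1").recipients == ["u-1"]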
""" + type: NodeType = NodeType.IF_ELSE + class Case(BaseModel): """ Case entity representing a single logical condition group diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py index 3c5a33e2b7..4b6d30c279 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IfElseNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IfElseNodeData.model_validate(node_data) - var_mapping: dict[str, list[str]] = {} - for case in typed_node_data.cases or []: + _ = graph_config # Explicitly mark as unused + for case in node_data.cases or []: for condition in case.conditions: key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" var_mapping[key] = condition.variable_selector diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py index a31b05463e..6d61c12352 100644 --- a/api/dify_graph/nodes/iteration/entities.py +++ b/api/dify_graph/nodes/iteration/entities.py @@ -3,7 +3,9 @@ from typing import Any from pydantic import Field -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState class ErrorHandleMode(StrEnum): @@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ + type: NodeType = NodeType.ITERATION parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector @@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData): Iteration Start Node Data. 
""" - pass + type: NodeType = NodeType.ITERATION_START class IterationState(BaseIterationState): diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 6d26cbfce4..1d626f4bd6 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( NodeExecutionType, NodeType, @@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IterationNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IterationNodeData.model_validate(node_data) - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": typed_node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } iteration_node_ids = set() # Find all nodes that belong to this loop nodes = graph_config.get("nodes", []) for node in nodes: - node_data = node.get("data", {}) - if node_data.get("iteration_id") == node_id: + node_config_data = node.get("data", {}) + if node_config_data.get("iteration_id") == node_id: in_iteration_node_id = node.get("id") if in_iteration_node_id: iteration_node_ids.add(in_iteration_node_id) @@ -488,16 +486,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # variable selector to variable mapping try: # Get node class - from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import get_node_type_classes_mapping - node_type = NodeType(sub_node_config.get("data", {}).get("type")) - if node_type not in NODE_TYPE_CLASSES_MAPPING: + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type + node_mapping = get_node_type_classes_mapping() + if node_type not in node_mapping: continue - node_version = sub_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_version = str(typed_sub_node_config["data"].version) + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: diff --git a/api/dify_graph/nodes/knowledge_index/entities.py b/api/dify_graph/nodes/knowledge_index/entities.py index 493b5eadd8..d88ee8e3af 100644 --- a/api/dify_graph/nodes/knowledge_index/entities.py +++ b/api/dify_graph/nodes/knowledge_index/entities.py @@ -3,7 +3,8 @@ from typing import Literal, Union from pydantic import BaseModel from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class RerankingModelConfig(BaseModel): @@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData): Knowledge index Node Data. 
""" - type: str = "knowledge-index" + type: NodeType = NodeType.KNOWLEDGE_INDEX chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None diff --git a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py index eeb4f3c229..3c4fe2344c 100644 --- a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py +++ b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py @@ -2,6 +2,7 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey from dify_graph.node_events import NodeRunResult @@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", index_processor: IndexProcessorProtocol, diff --git a/api/dify_graph/nodes/knowledge_retrieval/entities.py b/api/dify_graph/nodes/knowledge_retrieval/entities.py index c3059897c7..8f226b9785 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/entities.py +++ b/api/dify_graph/nodes/knowledge_retrieval/entities.py @@ -3,7 +3,8 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig @@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. 
""" - type: str = "knowledge-retrieval" + type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL query_variable_selector: list[str] | None | str = None query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index c67e14ce17..61c9614340 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, @@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", rag_retrieval: RAGRetrievalProtocol, @@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: KnowledgeRetrievalNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) - variable_mapping = {} - if typed_node_data.query_variable_selector: - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector - if typed_node_data.query_attachment_selector: - variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector + if node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = node_data.query_variable_selector + if node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector return variable_mapping diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py index 0fdd85f210..a91cfab8de 100644 --- a/api/dify_graph/nodes/list_operator/entities.py +++ b/api/dify_graph/nodes/list_operator/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class FilterOperator(StrEnum): @@ -62,6 +63,7 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): + type: NodeType = NodeType.LIST_OPERATOR variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy order_by: OrderByConfig diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 707ed8ece0..71728aa227 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -4,8 +4,9 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base import 
BaseNodeData from dify_graph.nodes.base.entities import VariableSelector @@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): + type: NodeType = NodeType.LLM model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5e59c96cd6..b88ff404c0 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.signature import sign_upload_file from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, SystemVariableKey, @@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, @@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = LLMNodeData.model_validate(node_data) - - prompt_template = typed_node_data.prompt_template + prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: @@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = typed_node_data.memory + memory = node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if typed_node_data.context.enabled: - variable_mapping["#context#"] = typed_node_data.context.variable_selector + if node_data.context.enabled: + variable_mapping["#context#"] = node_data.context.variable_selector - if typed_node_data.vision.enabled: - variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector + if node_data.vision.enabled: + variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if typed_node_data.memory: + if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if typed_node_data.prompt_config: + if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): @@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]): break if enable_jinja: - for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: + for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." 
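# Editor's illustration of what the VariableTemplateParser calls above extract
# from a prompt template; this regex mimics the {{#node.var#}} syntax and is
# not the parser's actual implementation.
import re

def extract_selectors(template: str) -> dict[str, list[str]]:
    return {
        m.group(1): m.group(1).split(".")
        for m in re.finditer(r"\{\{#([A-Za-z0-9_.]+)#\}\}", template)
    }

assert extract_selectors("Answer {{#sys.query#}} with {{#kb.result#}}") == {
    "sys.query": ["sys", "query"],
    "kb.result": ["kb", "result"],
}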
+ key: value for key, value in variable_mapping.items()} diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py index b4a8518048..8a3df5c234 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState from dify_graph.utils.condition.entities import Condition from dify_graph.variables.types import SegmentType @@ -39,6 +41,7 @@ class LoopVariableData(BaseModel): class LoopNodeData(BaseLoopNodeData): + type: NodeType = NodeType.LOOP loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] @@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData): Loop Start Node Data. """ - pass + type: NodeType = NodeType.LOOP_START class LoopEndNodeData(BaseNodeData): @@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData): Loop End Node Data. """ - pass + type: NodeType = NodeType.LOOP_END class LoopState(BaseLoopState): diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 8279f0fc66..1a8774f445 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( NodeExecutionType, NodeType, @@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LoopNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = LoopNodeData.model_validate(node_data) - variable_mapping = {} # Extract loop node IDs statically from graph_config @@ -318,16 +316,18 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # variable selector to variable mapping try: # Get node class - from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + from dify_graph.nodes.node_mapping import get_node_type_classes_mapping - node_type = NodeType(sub_node_config.get("data", {}).get("type")) - if node_type not in NODE_TYPE_CLASSES_MAPPING: + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type + node_mapping = get_node_type_classes_mapping() + if node_type not in node_mapping: continue - node_version = sub_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_version = str(typed_sub_node_config["data"].version) + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: @@ -342,7 +342,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): 
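# Editor's recap (names and call shape taken from this diff) of the sub-node
# resolution Iteration and Loop now share: validate the raw config into a typed
# NodeConfigDict, then look the class up in the live registry by node type and
# stringified version.
typed_cfg = NodeConfigDictAdapter.validate_python(sub_node_config)
node_cls = get_node_type_classes_mapping()[typed_cfg["data"].type][str(typed_cfg["data"].version)]
sub_mapping = node_cls.extract_variable_selector_to_variable_mapping(
    graph_config=graph_config, config=typed_cfg
)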
variable_mapping.update(sub_node_variable_mapping) - for loop_variable in typed_node_data.loop_variables or []: + for loop_variable in node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping diff --git a/api/dify_graph/nodes/node_mapping.py b/api/dify_graph/nodes/node_mapping.py index 8e5405f1aa..e0f5524a04 100644 --- a/api/dify_graph/nodes/node_mapping.py +++ b/api/dify_graph/nodes/node_mapping.py @@ -5,5 +5,24 @@ from dify_graph.nodes.base.node import Node LATEST_VERSION = "latest" -# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + +def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: + """Return the live node registry after importing all `dify_graph.nodes` modules.""" + return Node.get_node_type_classes_mapping() + + +def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + node_mapping = get_node_type_classes_mapping().get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class + + +# Snapshot kept for compatibility with older tests; production paths should use the live helpers. +NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 3b042710f9..8f8a278d5b 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -8,7 +8,8 @@ from pydantic import ( ) from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig from dify_graph.variables.types import SegmentType @@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData): Parameter Extractor Node Data. 
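# Editor's usage sketch for the new resolver above: an exact version match
# wins, unknown versions fall back to the "latest" entry, and an unregistered
# node type raises ValueError.
from dify_graph.enums import NodeType
from dify_graph.nodes.node_mapping import resolve_node_class

exact_cls = resolve_node_class(node_type=NodeType.CODE, node_version="1")
fallback_cls = resolve_node_class(node_type=NodeType.CODE, node_version="999")  # falls back to latest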
""" + type: NodeType = NodeType.PARAMETER_EXTRACTOR model: ModelConfig query: list[str] parameters: list[ParameterConfig] diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 1325a6a09a..68bd15db30 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, @@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ParameterExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = ParameterExtractorNodeData.model_validate(node_data) + _ = graph_config # Explicitly mark as unused + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} - - if typed_node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) + if node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py index 03e0a0ac53..77a6c70c28 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,7 +1,8 @@ from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm import ModelConfig, VisionConfig @@ -11,6 +12,7 @@ class ClassConfig(BaseModel): class QuestionClassifierNodeData(BaseNodeData): + type: NodeType = NodeType.QUESTION_CLASSIFIER query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 443d216186..a61bca4ea9 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -7,6 +7,7 @@ from core.model_manager import ModelInstance from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from 
dify_graph.enums import ( NodeExecutionType, NodeType, @@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: QuestionClassifierNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = QuestionClassifierNodeData.model_validate(node_data) - - variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors: list[VariableSelector] = [] - if typed_node_data.instruction: - variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) + if node_data.instruction: + variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py index 0df832740e..cbf7348360 100644 --- a/api/dify_graph/nodes/start/entities.py +++ b/api/dify_graph/nodes/start/entities.py @@ -2,7 +2,8 @@ from collections.abc import Sequence from pydantic import Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.variables.input_entities import VariableEntity @@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData): Start Node Data """ + type: NodeType = NodeType.START variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py index 123fd41f81..2a79a82870 100644 --- a/api/dify_graph/nodes/template_transform/entities.py +++ b/api/dify_graph/nodes/template_transform/entities.py @@ -1,4 +1,5 @@ -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import VariableSelector @@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData): Template Transform Node Data. 
""" + type: NodeType = NodeType.TEMPLATE_TRANSFORM variables: list[VariableSelector] template: str diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py index 367442e997..9dfb535342 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -1,6 +1,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = TemplateTransformNodeData.model_validate(node_data) - return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index f15dabdeeb..4ba8c16e85 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class ToolEntity(BaseModel): @@ -32,6 +33,8 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): + type: NodeType = NodeType.TOOL + class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index a6e0b710f1..ec7386981e 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, SystemVariableKey, @@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -64,6 +65,10 @@ class ToolNode(Node[ToolNodeData]): def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + event.provider_type = self.node_data.provider_type + def _run(self) -> 
Generator[NodeEventBase, None, None]: """ Run the tool node @@ -484,7 +489,7 @@ class ToolNode(Node[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ToolNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -493,9 +498,8 @@ class ToolNode(Node[ToolNodeData]): :param node_data: node data :return: """ - # Create typed NodeData from dict - typed_node_data = ToolNodeData.model_validate(node_data) - + _ = graph_config # Explicitly mark as unused + typed_node_data = node_data result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] diff --git a/api/dify_graph/nodes/trigger_plugin/entities.py b/api/dify_graph/nodes/trigger_plugin/entities.py index 75d10ecaa4..33a61c9bc8 100644 --- a/api/dify_graph/nodes/trigger_plugin/entities.py +++ b/api/dify_graph/nodes/trigger_plugin/entities.py @@ -4,13 +4,16 @@ from typing import Any, Literal, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.entities.entities import EventParameter -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError class TriggerEventNodeData(BaseNodeData): """Plugin trigger node data""" + type: NodeType = NodeType.TRIGGER_PLUGIN + class TriggerEventInput(BaseModel): value: Union[Any, list[str]] type: Literal["mixed", "variable", "constant"] @@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData): raise ValueError("value must be a string, int, float, bool or dict") return type - title: str - desc: str | None = None plugin_id: str = Field(..., description="Plugin ID") provider_id: str = Field(..., description="Provider ID") event_name: str = Field(..., description="Event name") diff --git a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py index b4f1116f7e..536ba96dec 100644 --- a/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py +++ b/api/dify_graph/nodes/trigger_plugin/trigger_event_node.py @@ -32,6 +32,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]): def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + def _run(self) -> NodeRunResult: """ Run the plugin trigger node. 
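# A minimal sketch of the `populate_start_event` hook added to ToolNode and
# TriggerEventNode above, under the assumption that the base Node class ships
# a no-op implementation which the engine calls while building each node's
# start event. The class and field names below are illustrative, not the
# actual dify_graph types.
from dataclasses import dataclass


@dataclass
class StartEventSketch:
    node_id: str
    provider_id: str | None = None
    provider_type: str | None = None


class NodeSketch:
    def populate_start_event(self, event: StartEventSketch) -> None:
        """Default hook: most node types add nothing to their start event."""

    def emit_start_event(self, node_id: str) -> StartEventSketch:
        event = StartEventSketch(node_id=node_id)
        # Subclasses stamp provider metadata here, so the engine never has
        # to special-case tool or trigger nodes when emitting start events.
        self.populate_start_event(event)
        return event


class ToolNodeSketch(NodeSketch):
    def __init__(self, provider_id: str, provider_type: str) -> None:
        self.provider_id = provider_id
        self.provider_type = provider_type

    def populate_start_event(self, event: StartEventSketch) -> None:
        event.provider_id = self.provider_id
        event.provider_type = self.provider_type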
diff --git a/api/dify_graph/nodes/trigger_schedule/entities.py b/api/dify_graph/nodes/trigger_schedule/entities.py index 6daadc7666..2b0edcabba 100644 --- a/api/dify_graph/nodes/trigger_schedule/entities.py +++ b/api/dify_graph/nodes/trigger_schedule/entities.py @@ -2,7 +2,8 @@ from typing import Literal, Union from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): @@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData): Trigger Schedule Node Data """ + type: NodeType = NodeType.TRIGGER_SCHEDULE mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") diff --git a/api/dify_graph/nodes/trigger_schedule/exc.py b/api/dify_graph/nodes/trigger_schedule/exc.py index caea6241e4..336d64d58f 100644 --- a/api/dify_graph/nodes/trigger_schedule/exc.py +++ b/api/dify_graph/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/dify_graph/nodes/trigger_webhook/entities.py b/api/dify_graph/nodes/trigger_webhook/entities.py index fa36aeabd3..a4f8745e71 100644 --- a/api/dify_graph/nodes/trigger_webhook/entities.py +++ b/api/dify_graph/nodes/trigger_webhook/entities.py @@ -1,10 +1,41 @@ from collections.abc import Sequence from enum import StrEnum -from typing import Literal from pydantic import BaseModel, Field, field_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.variables.types import SegmentType + +_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + } +) + +_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + } +) + +_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES + +_WEBHOOK_BODY_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_BOOLEAN, + SegmentType.ARRAY_OBJECT, + SegmentType.FILE, + } +) class Method(StrEnum): @@ -25,29 +56,34 @@ class ContentType(StrEnum): class WebhookParameter(BaseModel): - """Parameter definition for headers, query params, or body.""" + """Parameter definition for headers or query params.""" name: str + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: SegmentType) -> SegmentType: + if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook parameter type: {v}") + return v + class WebhookBodyParameter(BaseModel): """Body parameter with type information.""" name: str - type: Literal[ - "string", - "number", - "boolean", - "object", - "array[string]", - "array[number]", - "array[boolean]", - "array[object]", - "file", - ] = "string" + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: 
SegmentType) -> SegmentType: + if v not in _WEBHOOK_BODY_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook body parameter type: {v}") + return v + class WebhookData(BaseNodeData): """ @@ -57,6 +93,7 @@ class WebhookData(BaseNodeData): class SyncMode(StrEnum): SYNC = "async" # only support + type: NodeType = NodeType.TRIGGER_WEBHOOK method: Method = Method.GET content_type: ContentType = Field(default=ContentType.JSON) headers: Sequence[WebhookParameter] = Field(default_factory=list) @@ -71,6 +108,22 @@ class WebhookData(BaseNodeData): return v.lower() return v + @field_validator("headers", mode="after") + @classmethod + def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook header parameter type: {param.type}") + return v + + @field_validator("params", mode="after") + @classmethod + def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook query parameter type: {param.type}") + return v + status_code: int = 200 # Expected status code for response response_body: str = "" # Template for response body diff --git a/api/dify_graph/nodes/trigger_webhook/exc.py b/api/dify_graph/nodes/trigger_webhook/exc.py index 853b2456c5..4d87f2a069 100644 --- a/api/dify_graph/nodes/trigger_webhook/exc.py +++ b/api/dify_graph/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/dify_graph/nodes/trigger_webhook/node.py b/api/dify_graph/nodes/trigger_webhook/node.py index e466541908..413eda5272 100644 --- a/api/dify_graph/nodes/trigger_webhook/node.py +++ b/api/dify_graph/nodes/trigger_webhook/node.py @@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == "file": + if param_type == SegmentType.FILE: # Get File object (already processed by webhook controller) files = webhook_data.get("files", {}) if files and isinstance(files, dict): diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py index 5f7c1dbe93..fec4c4474c 100644 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ b/api/dify_graph/nodes/variable_aggregator/entities.py @@ -1,6 +1,7 @@ from pydantic import BaseModel -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.variables.types import SegmentType @@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData): Variable Aggregator Node Data. 
""" + type: NodeType = NodeType.VARIABLE_AGGREGATOR output_type: str variables: list[list[str]] advanced_settings: AdvancedSettings | None = None diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index 1aa7042b02..1d17b981ba 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerData.model_validate(node_data) - mapping = {} - assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] + assigned_variable_node_id = node_data.assigned_variable_selector[0] if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(typed_node_data.assigned_variable_selector) + selector_key = ".".join(node_data.assigned_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.assigned_variable_selector + mapping[key] = node_data.assigned_variable_selector - selector_key = ".".join(typed_node_data.input_variable_selector) + selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.input_variable_selector + mapping[key] = node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py index 11e8f93f35..a75a2397ba 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node_data.py @@ -1,7 +1,8 @@ from collections.abc import Sequence from enum import StrEnum -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class WriteMode(StrEnum): @@ -11,6 +12,7 @@ class WriteMode(StrEnum): class VariableAssignerData(BaseNodeData): + type: NodeType = NodeType.VARIABLE_ASSIGNER assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py index 5f9211d600..ca3a94b777 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ b/api/dify_graph/nodes/variable_assigner/v2/entities.py @@ -3,7 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from .enums import InputType, Operation @@ -22,5 +23,6 @@ class 
VariableOperationItem(BaseModel): class VariableAssignerNodeData(BaseNodeData): + type: NodeType = NodeType.VARIABLE_ASSIGNER version: str = "2" items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index 7753382cd0..771609ceb6 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerNodeData.model_validate(node_data) - var_mapping: dict[str, Sequence[str]] = {} - for item in typed_node_data.items: + for item in node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 7ee4638e77..a94d75ec76 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -17,7 +17,8 @@ from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from models.workflow import WorkflowNodeExecutionModel +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository logger = logging.getLogger(__name__) @@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.triggered_from = data.get("triggered_from") or "" + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val) + model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" 
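# The try/except coercion used above for `triggered_from` (and below for
# `created_by_role`, and in the workflow-run repository for `type` and
# `status`) repeats one shape: parse the raw logstore value into a StrEnum,
# log and fall back to a default on bad input. Here is a sketch of a helper
# that captures that pattern; the patch inlines each case instead, and
# `_coerce_enum` does not exist in the repository.
import enum
import logging
from typing import TypeVar

logger = logging.getLogger(__name__)

E = TypeVar("E", bound=enum.Enum)


def _coerce_enum(raw: object, enum_cls: type[E], default: E) -> E:
    """Coerce `raw` to a member of `enum_cls`, falling back to `default`."""
    if not raw:
        return default
    try:
        return enum_cls(str(raw))
    except ValueError:
        logger.warning("Invalid %s value: %s, falling back to %s", enum_cls.__name__, raw, default)
        return default


# e.g.: model.triggered_from = _coerce_enum(
#     data.get("triggered_from"),
#     WorkflowNodeExecutionTriggeredFrom,
#     WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
# )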
model.status = data.get("status") or "running" # Default status if missing model.title = data.get("title") or "" - model.created_by_role = data.get("created_by_role") or "" + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.index = safe_int(data.get("index", 0)) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 14382ed876..bdfc81bd1c 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,12 +22,13 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker +from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.types import ( AverageInteractionStats, @@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.type = data.get("type") or "" - model.triggered_from = data.get("triggered_from") or "" + type_val = data.get("type") + try: + model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW + except ValueError: + logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val) + model.type = WorkflowType.WORKFLOW + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowRunTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowRunTriggeredFrom.APP_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val) + model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN model.version = data.get("version") or "" - model.status = data.get("status") or "running" # Default status if missing - model.created_by_role = data.get("created_by_role") or "" + status_val = data.get("status") + try: + model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING + except ValueError: + logger.warning("Invalid status value: %s, falling back to RUNNING", status_val) + model.status = WorkflowExecutionStatus.RUNNING + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + 
model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.total_tokens = safe_int(data.get("total_tokens", 0)) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index a9ff0eed22..b1c703f944 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -7,7 +7,7 @@ from celery.signals import worker_init from flask_login import user_loaded_from_request, user_logged_in from opentelemetry import trace from opentelemetry.propagate import set_global_textmap -from opentelemetry.propagators.b3 import B3Format +from opentelemetry.propagators.b3 import B3MultiFormat from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -24,7 +24,7 @@ def setup_context_propagation() -> None: CompositePropagator( [ TraceContextTextMapPropagator(), - B3Format(), + B3MultiFormat(), ] ) ) diff --git a/api/models/account.py b/api/models/account.py index f7a9c20026..1a43c9ca17 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,12 +8,12 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, validates +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class TenantAccountRole(enum.StrEnum): @@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase): last_active_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False ) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active") + status: Mapped[AccountStatus] = mapped_column( + EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE + ) initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) - @validates("status") - def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: - if isinstance(value, AccountStatus): - return value.value - return value - @property def is_password_set(self): return self.password is not None @@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase): return self.role def get_status(self) -> AccountStatus: - status_str = self.status - return AccountStatus(status_str) + return self.status @classmethod def get_by_openid(cls, provider: str, open_id: str): @@ -249,7 +244,9 @@ class Tenant(TypeBase): name: Mapped[str] = mapped_column(String(255)) encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic") - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal") + status: Mapped[TenantStatus] = mapped_column( + EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL + ) custom_config: Mapped[str | None] = 
mapped_column(LongText, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) - role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + role: Mapped[TenantAccountRole] = mapped_column( + EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL + ) invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -324,6 +323,11 @@ class AccountIntegrate(TypeBase): ) +class InvitationCodeStatus(enum.StrEnum): + UNUSED = "unused" + USED = "used" + + class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( @@ -335,7 +339,11 @@ class InvitationCode(TypeBase): id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'"), default="unused") + status: Mapped[InvitationCodeStatus] = mapped_column( + EnumText(InvitationCodeStatus, length=16), + server_default=sa.text("'unused'"), + default=InvitationCodeStatus.UNUSED, + ) used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) @@ -367,10 +375,13 @@ class TenantPluginPermission(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column( - String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + EnumText(InstallPermission, length=16), + nullable=False, + server_default="everyone", + default=InstallPermission.EVERYONE, ) debug_permission: Mapped[DebugPermission] = mapped_column( - String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + EnumText(DebugPermission, length=16), nullable=False, server_default="noone", default=DebugPermission.NOBODY ) @@ -396,10 +407,13 @@ class TenantPluginAutoUpgradeStrategy(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) strategy_setting: Mapped[StrategySetting] = mapped_column( - String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + EnumText(StrategySetting, length=16), + nullable=False, + server_default="fix_only", + default=StrategySetting.FIX_ONLY, ) upgrade_mode: Mapped[UpgradeMode] = mapped_column( - String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + EnumText(UpgradeMode, length=16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE ) exclude_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) include_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) diff --git a/api/models/dataset.py b/api/models/dataset.py index 4ef39fcde1..b3fa11a58c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -30,8 +30,9 @@ from 
services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db +from .enums import CreatorUserRole from .model import App, Tag, TagBinding, UploadFile -from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index +from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index logger = logging.getLogger(__name__) @@ -59,7 +60,11 @@ class Dataset(Base): name: Mapped[str] = mapped_column(String(255)) description = mapped_column(LongText, nullable=True) provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'")) - permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'")) + permission: Mapped[DatasetPermissionEnum] = mapped_column( + EnumText(DatasetPermissionEnum, length=255), + server_default=sa.text("'only_me'"), + default=DatasetPermissionEnum.ONLY_ME, + ) data_source_type = mapped_column(String(255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) @@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase): content: Mapped[str] = mapped_column(LongText, nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False diff --git a/api/models/enums.py b/api/models/enums.py index ed6236209f..66e3e4b332 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -72,3 +72,23 @@ class AppTriggerType(StrEnum): # for backward compatibility UNKNOWN = "unknown" + + +class AppStatus(StrEnum): + """App Status Enum""" + + NORMAL = "normal" + + +class AppMCPServerStatus(StrEnum): + """AppMCPServer Status Enum""" + + NORMAL = "normal" + ACTIVE = "active" + INACTIVE = "inactive" + + +class ConversationStatus(StrEnum): + """Conversation Status Enum""" + + NORMAL = "normal" diff --git a/api/models/model.py b/api/models/model.py index ed0614c195..2e747df2c7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import CreatorUserRole +from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus from .provider_ids import GenericProviderID -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from .workflow import Workflow @@ -337,13 +337,15 @@ class App(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) - mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255)) icon = mapped_column(String(255)) icon_background: Mapped[str | None] = 
mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) enable_site: Mapped[bool] = mapped_column(sa.Boolean) enable_api: Mapped[bool] = mapped_column(sa.Boolean) api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) @@ -1000,14 +1002,16 @@ class Conversation(Base): model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(LongText) model_id = mapped_column(String(255), nullable=True) - mode: Mapped[str] = mapped_column(String(255)) + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(LongText) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) introduction = mapped_column(LongText) system_instruction = mapped_column(LongText) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - status: Mapped[str] = mapped_column(String(255), nullable=False) + status: Mapped[ConversationStatus] = mapped_column( + EnumText(ConversationStatus, length=255), nullable=False, default=ConversationStatus.NORMAL + ) # The `invoke_from` records how the conversation is created. # @@ -1351,7 +1355,12 @@ class Message(Base): provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[MessageStatus] = mapped_column( + EnumText(MessageStatus, length=255), + nullable=False, + server_default=sa.text("'normal'"), + default=MessageStatus.NORMAL, + ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) @@ -1364,7 +1373,7 @@ class Message(Base): ) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) - app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True) + app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True) @property def inputs(self) -> dict[str, Any]: @@ -1766,8 +1775,10 @@ class MessageFile(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[FileTransferMethod] = mapped_column( + EnumText(FileTransferMethod, length=255), nullable=False + ) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) url: Mapped[str | None] = 
mapped_column(LongText, nullable=True, default=None) @@ -1976,7 +1987,9 @@ class AppMCPServer(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=False) server_code: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppMCPServerStatus] = mapped_column( + EnumText(AppMCPServerStatus, length=255), nullable=False, server_default=sa.text("'normal'") + ) parameters: Mapped[str] = mapped_column(LongText, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2015,7 +2028,7 @@ class Site(Base): id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - icon_type = mapped_column(String(255), nullable=True) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True) icon = mapped_column(String(255)) icon_background = mapped_column(String(255)) description = mapped_column(LongText) @@ -2030,7 +2043,9 @@ class Site(Base): customize_domain = mapped_column(String(255)) customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -2110,7 +2125,12 @@ class UploadFile(Base): # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # The `created_by` field stores the ID of the entity that created this upload file. 
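# `EnumText` itself lives in models/types.py and is not shown in this diff.
# The sketch below is an assumption about how such a column type plausibly
# works: a SQLAlchemy TypeDecorator over VARCHAR that persists the StrEnum's
# raw value and hands enum members back on load, which is what lets columns
# like `created_by_role` above be typed as Mapped[CreatorUserRole].
import enum

import sqlalchemy as sa


class EnumTextSketch(sa.types.TypeDecorator):
    impl = sa.String
    cache_ok = True

    def __init__(self, enum_cls: type[enum.StrEnum], length: int = 255):
        super().__init__(length)
        self._enum_cls = enum_cls

    def process_bind_param(self, value, dialect):
        # Accept either an enum member or its raw string; store the plain value.
        if value is None:
            return None
        return self._enum_cls(value).value

    def process_result_value(self, value, dialect):
        # Return enum members so model code can compare against
        # CreatorUserRole.ACCOUNT and friends without .value juggling.
        return self._enum_cls(value) if value is not None else None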
# @@ -2163,7 +2183,7 @@ class UploadFile(Base): self.size = size self.extension = extension self.mime_type = mime_type - self.created_by_role = created_by_role.value + self.created_by_role = created_by_role self.created_by = created_by self.created_at = created_at self.used = used @@ -2226,7 +2246,7 @@ class MessageAgentThought(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) diff --git a/api/models/provider.py b/api/models/provider.py index 6175a3ae88..18a0fe92c8 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class ProviderType(StrEnum): @@ -69,8 +69,8 @@ class Provider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - provider_type: Mapped[str] = mapped_column( - String(40), nullable=False, server_default=text("'custom'"), default="custom" + provider_type: Mapped[ProviderType] = mapped_column( + EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False) last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False) diff --git a/api/models/trigger.py b/api/models/trigger.py index 209345eb84..43d7fc5b24 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase): queue_name: Mapped[str] = mapped_column(String(100), nullable=False) celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(String(255), nullable=False) retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) diff --git a/api/models/web.py b/api/models/web.py index 5f6a7b40bf..a1cc11c375 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,13 +2,14 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func +from sqlalchemy import DateTime, func from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase from .engine import db +from .enums import CreatorUserRole from .model import Message -from .types import StringUUID +from .types import EnumText, StringUUID class SavedMessage(TypeBase): @@ -24,7 +25,9 @@ class SavedMessage(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = 
mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'") + ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, @@ -50,8 +53,8 @@ class PinnedConversation(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role: Mapped[str] = mapped_column( - String(255), + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'"), ) diff --git a/api/models/workflow.py b/api/models/workflow.py index d728ed83bc..8c62292079 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -53,7 +53,7 @@ from libs import helper from .account import Account from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db -from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType +from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) @@ -141,7 +141,7 @@ class Workflow(Base): # bug id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False) version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="") marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="") @@ -188,7 +188,7 @@ class Workflow(Base): # bug workflow.id = str(uuid4()) workflow.tenant_id = tenant_id workflow.app_id = app_id - workflow.type = type + workflow.type = WorkflowType(type) workflow.version = version workflow.graph = graph workflow.features = features @@ -233,8 +233,11 @@ class Workflow(Base): # bug def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. - A node configuration is a dictionary containing the node's properties, including - the node's id, title, and its data as a dict. + + A node configuration includes the node id and a typed `BaseNodeData` for `data`. + `BaseNodeData` keeps a dict-like `get`/`__getitem__` compatibility layer backed by + model fields plus Pydantic extra storage for legacy consumers, but callers should + prefer attribute access. 
""" workflow_graph = self.graph_dict @@ -252,12 +255,9 @@ class Workflow(Base): # bug return NodeConfigDictAdapter.validate_python(node_config) @staticmethod - def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" - node_config_data = node_config.get("data", {}) - # Get node class - node_type = NodeType(node_config_data.get("type")) - return node_type + return node_config["data"].type @staticmethod def get_enclosing_node_type_and_id( @@ -608,8 +608,8 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(String(255)) - triggered_from: Mapped[str] = mapped_column(String(255)) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255)) + triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255)) version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) @@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column( + EnumText(WorkflowNodeExecutionTriggeredFrom, length=255) + ) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) @@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(String(255)) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase): workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from: Mapped[str] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, 
length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) @@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase): run_version: Mapped[str] = mapped_column(String(255), nullable=False) run_status: Mapped[str] = mapped_column(String(255), nullable=False) - run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( + EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False + ) run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) diff --git a/api/pyproject.toml b/api/pyproject.toml index efe219e33a..64df4d1e77 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -53,7 +53,7 @@ dependencies = [ "opentelemetry-instrumentation-httpx==0.49b0", "opentelemetry-instrumentation-redis==0.49b0", "opentelemetry-instrumentation-sqlalchemy==0.49b0", - "opentelemetry-propagator-b3==1.28.0", + "opentelemetry-propagator-b3==1.40.0", "opentelemetry-proto==1.28.0", "opentelemetry-sdk==1.28.0", "opentelemetry-semantic-conventions==0.49b0", @@ -65,7 +65,7 @@ dependencies = [ "pydantic~=2.12.5", "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", - "pyjwt~=2.11.0", + "pyjwt~=2.12.0", "pypdfium2==5.2.0", "python-docx~=1.2.0", "python-dotenv==1.0.1", @@ -109,46 +109,46 @@ package = false # Required for development and running tests ############################################################ dev = [ - "coverage~=7.2.4", - "dotenv-linter~=0.5.0", - "faker~=38.2.0", + "coverage~=7.13.4", + "dotenv-linter~=0.7.0", + "faker~=40.8.0", "lxml-stubs~=0.5.1", "basedpyright~=1.38.2", - "ruff~=0.14.0", - "pytest~=8.3.2", - "pytest-benchmark~=4.0.0", - "pytest-cov~=4.1.0", + "ruff~=0.15.5", + "pytest~=9.0.2", + "pytest-benchmark~=5.2.3", + "pytest-cov~=7.0.0", "pytest-env~=1.1.3", - "pytest-mock~=3.14.0", + "pytest-mock~=3.15.1", "testcontainers~=4.13.2", "types-aiofiles~=25.1.0", "types-beautifulsoup4~=4.12.0", - "types-cachetools~=5.5.0", + "types-cachetools~=6.2.0", "types-colorama~=0.4.15", "types-defusedxml~=0.7.0", - "types-deprecated~=1.2.15", - "types-docutils~=0.21.0", - "types-jsonschema~=4.23.0", - "types-flask-cors~=5.0.0", + "types-deprecated~=1.3.1", + "types-docutils~=0.22.3", + "types-jsonschema~=4.26.0", + "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", "types-greenlet~=3.3.0", "types-html5lib~=1.1.11", "types-markdown~=3.10.2", - "types-oauthlib~=3.2.0", + "types-oauthlib~=3.3.0", "types-objgraph~=3.6.0", "types-olefile~=0.47.0", "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", - "types-protobuf~=5.29.1", + "types-protobuf~=6.32.1", "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", "types-python-dateutil~=2.9.0", - "types-pywin32~=310.0.0", + "types-pywin32~=311.0.0", "types-pyyaml~=6.0.12", - "types-regex~=2024.11.6", + "types-regex~=2026.2.28", "types-shapely~=2.1.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", @@ -161,7 +161,7 @@ dev = [ "types_pyOpenSSL>=24.1.0", "types_cffi>=1.17.0", "types_setuptools>=80.9.0", - "pandas-stubs~=2.2.3", + "pandas-stubs~=3.0.0", "scipy-stubs>=1.15.3.0", "types-python-http-client>=3.3.7.20240910", "import-linter>=2.3", diff --git a/api/services/account_service.py 
b/api/services/account_service.py index f0eac2a522..bd520f54cf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1089,9 +1089,9 @@ class TenantService: ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: - ta.role = role + ta.role = TenantAccountRole(role) else: - ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role)) db.session.add(ta) db.session.commit() @@ -1319,10 +1319,10 @@ class TenantService: db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() ) if current_owner_join: - current_owner_join.role = "admin" + current_owner_join.role = TenantAccountRole.ADMIN # Update the role of the target member - target_member_join.role = new_role + target_member_join.role = TenantAccountRole(new_role) db.session.commit() @staticmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 06f4ccb90e..49ca273442 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -429,17 +429,18 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") + resolved_icon_type: IconType if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: - icon_type = icon_type_value + resolved_icon_type = IconType(icon_type_value) else: - icon_type = IconType.EMOJI + resolved_icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: # Update existing app app.name = name or app_data.get("name", app.name) app.description = description or app_data.get("description", app.description) - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id @@ -452,10 +453,10 @@ class AppDslService: app = App() app.id = str(uuid4()) app.tenant_id = account.current_tenant_id - app.mode = app_mode.value + app.mode = app_mode app.name = name or app_data.get("name", "") app.description = description or app_data.get("description", "") - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") app.enable_site = True @@ -549,7 +550,7 @@ class AppDslService: "kind": "app", "app": { "name": app_model.name, - "mode": app_model.mode, + "mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode, "icon": app_model.icon if app_model.icon_type == "image" else "🤖", "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, "description": app_model.description, diff --git a/api/services/app_service.py b/api/services/app_service.py index aba8954f1a..b5e893c5b5 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, IconType, Site from models.tools import ApiToolProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -254,7 +254,7 @@ class AppService: assert current_user is not None app.name = 
args["name"] app.description = args["description"] - app.icon_type = args["icon_type"] + app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3a7d483a9d..c527c71d7b 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -254,7 +254,7 @@ class DatasetService: dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model_name if embedding_model else None dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None - dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME dataset.provider = provider if summary_index_setting is not None: dataset.summary_index_setting = summary_index_setting diff --git a/api/services/file_service.py b/api/services/file_service.py index e08b78bf4c..ecb30faaa8 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -58,8 +58,9 @@ class FileService: # get file extension extension = os.path.splitext(filename)[1].lstrip(".").lower() - # check if filename contains invalid characters - if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]): + # Only reject path separators here. The original filename is stored as metadata, + # while the storage key is UUID-based. + if any(c in filename for c in ["/", "\\"]): raise ValueError("Filename contains invalid characters") if len(filename) > 200: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index c00c76a826..d85b290534 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -13,6 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery +from models.enums import CreatorUserRole logger = logging.getLogger(__name__) @@ -98,7 +99,7 @@ class HitTestingService: content=json.dumps(dataset_queries), source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) db.session.add(dataset_query) @@ -138,7 +139,7 @@ class HitTestingService: content=query, source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index ce745a4679..b9a565ec17 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,6 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, @@ -48,7 +49,6 @@ from dify_graph.graph_events.base import GraphNodeEventBase from dify_graph.node_events.base import NodeRunResult from 
dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from dify_graph.runtime import VariablePool from dify_graph.system_variable import SystemVariable @@ -381,7 +381,7 @@ class RagPipelineService: """ # return default block config default_block_configs: list[dict[str, Any]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None if node_type is NodeType.HTTP_REQUEST: @@ -410,12 +410,13 @@ class RagPipelineService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_workflow_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return None - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] final_filters = dict(filters) if filters else {} if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: final_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py index d4a6e87585..64dad7ba52 100644 --- a/api/services/retention/workflow_run/restore_archived_workflow_run.py +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -358,21 +358,19 @@ class WorkflowRunRestore: self, model: type[DeclarativeBase] | Any, ) -> tuple[set[str], set[str], set[str]]: - columns = list(model.__table__.columns) + table = model.__table__ + columns = list(table.columns) + autoincrement_column = getattr(table, "autoincrement_column", None) + + def has_insert_default(column: Any) -> bool: + # SQLAlchemy may set column.autoincrement to "auto" on non-PK columns. + # Only treat the resolved autoincrement column as DB-generated. 
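# Illustration only (not part of the patch): the classification rule above can
# be exercised on a toy model. Table.autoincrement_column resolves the single
# auto-incrementing PK column, if any; model and table names are hypothetical.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class _Base(DeclarativeBase):
    pass

class _Example(_Base):
    __tablename__ = "example"
    id: Mapped[int] = mapped_column(primary_key=True)                # autoincrement PK
    name: Mapped[str] = mapped_column(sa.String(50))                 # truly required
    count: Mapped[int] = mapped_column(server_default=sa.text("0"))  # DB-defaulted

_table = _Example.__table__
_auto = _table.autoincrement_column

def _has_insert_default(column) -> bool:
    return column.default is not None or column.server_default is not None or column is _auto

# Only `name` must be supplied on insert; `id` and `count` are DB-generated.
assert {c.key for c in _table.columns if not c.nullable and not _has_insert_default(c)} == {"name"}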
+ return column.default is not None or column.server_default is not None or column is autoincrement_column + column_names = {column.key for column in columns} - required_columns = { - column.key - for column in columns - if not column.nullable - and column.default is None - and column.server_default is None - and not column.autoincrement - } + required_columns = {column.key for column in columns if not column.nullable and not has_insert_default(column)} non_nullable_with_default = { - column.key - for column in columns - if not column.nullable - and (column.default is not None or column.server_default is not None or column.autoincrement) + column.key for column in columns if not column.nullable and has_insert_default(column) } return column_names, required_columns, non_nullable_with_default diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4dd6c8107b..d0f4f27968 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -3,6 +3,7 @@ from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService @@ -54,7 +55,7 @@ class SavedMessageService: saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 8389ccbb34..88b640305d 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -1,14 +1,18 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.nodes import NodeType -from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from dify_graph.nodes.trigger_schedule.entities import ( + ScheduleConfig, + SchedulePlanUpdate, + TriggerScheduleNodeData, + VisualConfig, +) from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin @@ -176,26 +180,26 @@ class ScheduleService: return next_run_at @staticmethod - def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig: + def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig: """ Converts user-friendly visual schedule settings to cron expression. Maintains consistency with frontend UI expectations while supporting croniter's extended syntax. 
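For illustration, a much-simplified mapping in the spirit of visual_to_cron is sketched below; the frequency names and parameters are assumptions for the example, not the actual VisualConfig schema.

def visual_to_cron_sketch(frequency: str, hour: int = 0, minute: int = 0, weekday: int = 1) -> str:
    # Cron fields: minute hour day-of-month month day-of-week.
    if frequency == "hourly":
        return f"{minute} * * * *"
    if frequency == "daily":
        return f"{minute} {hour} * * *"
    if frequency == "weekly":
        return f"{minute} {hour} * * {weekday}"
    raise ValueError(f"Unsupported frequency: {frequency}")

assert visual_to_cron_sketch("daily", hour=9, minute=30) == "30 9 * * *"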
""" - node_data = node_config.get("data", {}) - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") - node_id = node_config.get("id", "start") + node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True) + mode = node_data.mode + timezone = node_data.timezone + node_id = node_config["id"] cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = node_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = str(node_data.get("frequency")) + frequency = str(node_data.frequency or "") if not frequency: raise ScheduleConfigError("Frequency is required for visual mode") - visual_config = VisualConfig(**node_data.get("visual_config", {})) + visual_config = VisualConfig.model_validate(node_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config) if not cron_expression: raise ScheduleConfigError("Cron expression is required for visual mode") @@ -239,19 +243,21 @@ class ScheduleService: if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value: continue - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") node_id = node.get("id", "start") + trigger_data = TriggerScheduleNodeData.model_validate(node_data) + mode = trigger_data.mode + timezone = trigger_data.timezone cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = trigger_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = node_data.get("frequency") - visual_config_dict = node_data.get("visual_config", {}) - visual_config = VisualConfig(**visual_config_dict) + frequency = trigger_data.frequency + if not frequency: + raise ScheduleConfigError("Frequency is required for visual mode") + visual_config = VisualConfig.model_validate(trigger_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency, visual_config) else: raise ScheduleConfigError(f"Invalid schedule mode: {mode}") diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index f1f0d0ea84..2343bbbd3d 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db @@ -41,7 +42,7 @@ class TriggerService: @classmethod def invoke_trigger_event( - cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent + cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent ) -> TriggerInvokeEventResponse: """Invoke a trigger event.""" subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id( @@ -50,7 +51,7 @@ class TriggerService: ) if not subscription: raise ValueError("Subscription not found") - node_data: 
TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {})) + node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True) request = TriggerHttpRequestCachingService.get_request(event.request_id) payload = TriggerHttpRequestCachingService.get_payload(event.request_id) # invoke trigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 285645edce..02977b934c 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -2,7 +2,7 @@ import json import logging import mimetypes import secrets -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any import orjson @@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import SegmentType +from dify_graph.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -57,7 +64,7 @@ class WebhookService: @classmethod def get_webhook_trigger_and_workflow( cls, webhook_id: str, is_debug: bool = False - ) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]: + ) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]: """Get webhook trigger, workflow, and node configuration. Args: @@ -135,7 +142,7 @@ class WebhookService: @classmethod def extract_and_validate_webhook_data( - cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any] + cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict ) -> dict[str, Any]: """Extract and validate webhook data in a single unified process. @@ -153,7 +160,7 @@ raw_data = cls.extract_webhook_data(webhook_trigger) # Validate HTTP metadata (method, content-type) - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) validation_result = cls._validate_http_metadata(raw_data, node_data) if not validation_result["valid"]: raise ValueError(validation_result["error"]) @@ -192,7 +199,7 @@ content_type = cls._extract_content_type(dict(request.headers)) # Route to appropriate extractor based on content type - extractors = { + extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = { "application/json": cls._extract_json_body, "application/x-www-form-urlencoded": cls._extract_form_body, "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger), @@ -214,7 +221,7 @@ return data @classmethod - def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Process and validate webhook data according to node configuration.
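The pattern recurring throughout these hunks, Model.model_validate(node_config["data"], from_attributes=True) replacing dict .get() lookups, relies on Pydantic v2 accepting both plain mappings and attribute-style objects in that mode. A minimal sketch with an illustrative stand-in model, not the real WebhookData schema:

from pydantic import BaseModel

class NodeDataSketch(BaseModel):  # stand-in for WebhookData and friends
    title: str
    status_code: int = 200

# A plain dict validates as usual...
a = NodeDataSketch.model_validate({"title": "Test Webhook"}, from_attributes=True)

# ...and so does any object exposing the same names as attributes.
class _ConfigObject:
    title = "Test Webhook"
    status_code = 201

b = NodeDataSketch.model_validate(_ConfigObject(), from_attributes=True)
assert (a.status_code, b.status_code) == (200, 201)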
Args: @@ -230,18 +237,13 @@ class WebhookService: result = raw_data.copy() # Validate and process headers - cls._validate_required_headers(raw_data["headers"], node_data.get("headers", [])) + cls._validate_required_headers(raw_data["headers"], node_data.headers) # Process query parameters with type conversion and validation - result["query_params"] = cls._process_parameters( - raw_data["query_params"], node_data.get("params", []), is_form_data=True - ) + result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True) # Process body parameters based on content type - configured_content_type = node_data.get("content_type", "application/json").lower() - result["body"] = cls._process_body_parameters( - raw_data["body"], node_data.get("body", []), configured_content_type - ) + result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type) return result @@ -424,7 +426,11 @@ class WebhookService: @classmethod def _process_parameters( - cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False + cls, + raw_params: dict[str, str], + param_configs: Sequence[WebhookParameter], + *, + is_form_data: bool = False, ) -> dict[str, Any]: """Process parameters with unified validation and type conversion. @@ -440,13 +446,13 @@ class WebhookService: ValueError: If required parameters are missing or validation fails """ processed = {} - configured_params = {config.get("name", ""): config for config in param_configs} + configured_params = {config.name: config for config in param_configs} # Process configured parameters for param_config in param_configs: - name = param_config.get("name", "") - param_type = param_config.get("type", SegmentType.STRING) - required = param_config.get("required", False) + name = param_config.name + param_type = param_config.type + required = param_config.required # Check required parameters if required and name not in raw_params: @@ -465,7 +471,10 @@ class WebhookService: @classmethod def _process_body_parameters( - cls, raw_body: dict[str, Any], body_configs: list, content_type: str + cls, + raw_body: dict[str, Any], + body_configs: Sequence[WebhookBodyParameter], + content_type: ContentType, ) -> dict[str, Any]: """Process body parameters based on content type and configuration. @@ -480,25 +489,28 @@ class WebhookService: Raises: ValueError: If required body parameters are missing or validation fails """ - if content_type in ["text/plain", "application/octet-stream"]: - # For text/plain and octet-stream, validate required content exists - if body_configs and any(config.get("required", False) for config in body_configs): - raw_content = raw_body.get("raw") - if not raw_content: - raise ValueError(f"Required body content missing for {content_type} request") - return raw_body + match content_type: + case ContentType.TEXT | ContentType.BINARY: + # For text/plain and octet-stream, validate required content exists + if body_configs and any(config.required for config in body_configs): + raw_content = raw_body.get("raw") + if not raw_content: + raise ValueError(f"Required body content missing for {content_type} request") + return raw_body + case _: + pass # For structured data (JSON, form-data, etc.) 
processed = {} - configured_params = {config.get("name", ""): config for config in body_configs} + configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs} for body_config in body_configs: - name = body_config.get("name", "") - param_type = body_config.get("type", SegmentType.STRING) - required = body_config.get("required", False) + name = body_config.name + param_type = body_config.type + required = body_config.required # Handle file parameters for multipart data - if param_type == SegmentType.FILE and content_type == "multipart/form-data": + if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA: # File validation is handled separately in extract phase continue @@ -508,7 +520,7 @@ class WebhookService: if name in raw_body: raw_value = raw_body[name] - is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"] + is_form_data = content_type in [ContentType.FORM_URLENCODED, ContentType.FORM_DATA] processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data) # Include unconfigured parameters @@ -519,7 +531,9 @@ class WebhookService: return processed @classmethod - def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any: + def _validate_and_convert_value( + cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool + ) -> Any: """Unified validation and type conversion for parameter values. Args: @@ -532,7 +546,8 @@ class WebhookService: Any: The validated and converted value Raises: - ValueError: If validation or conversion fails + ValueError: If validation or conversion fails. The original validation + error is preserved as ``__cause__`` for debugging. """ try: if is_form_data: @@ -542,10 +557,10 @@ class WebhookService: # JSON data should already be in correct types, just validate return cls._validate_json_value(param_name, value, param_type) except Exception as e: - raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") + raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") from e @classmethod - def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any: + def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any: """Convert form data string values to specified types. Args: @@ -576,7 +591,7 @@ class WebhookService: raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") @classmethod - def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any: + def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any: """Validate JSON values against expected types. 
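The hunk below swaps a hand-written validator table for SegmentType.is_valid(..., array_validation=ArrayValidation.ALL), which is internal to dify_graph. A rough standalone equivalent of the same checks, with an assumed "array[...]" naming scheme, might read:

def is_valid_sketch(value, expected: str) -> bool:
    scalars = {"string": str, "number": (int, float), "boolean": bool, "object": dict}
    if expected in scalars:
        # Like the removed lambdas, bools pass the "number" check, since
        # bool is a subclass of int in Python.
        return isinstance(value, scalars[expected])
    if expected.startswith("array[") and expected.endswith("]"):
        inner = expected[len("array["):-1]
        return isinstance(value, list) and all(is_valid_sketch(item, inner) for item in value)
    return True  # unknown types pass through, mirroring the warning path

assert is_valid_sketch([1, 2.5], "array[number]")
assert not is_valid_sketch(["a", 1], "array[string]")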
Args: @@ -590,43 +605,43 @@ class WebhookService: Raises: ValueError: If the value type doesn't match the expected type """ - type_validators = { - SegmentType.STRING: (lambda v: isinstance(v, str), "string"), - SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"), - SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"), - SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"), - SegmentType.ARRAY_STRING: ( - lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v), - "array of strings", - ), - SegmentType.ARRAY_NUMBER: ( - lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v), - "array of numbers", - ), - SegmentType.ARRAY_BOOLEAN: ( - lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v), - "array of booleans", - ), - SegmentType.ARRAY_OBJECT: ( - lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v), - "array of objects", - ), - } - - validator_info = type_validators.get(SegmentType(param_type)) - if not validator_info: - logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name) + if param_type_enum is None: return value - validator, expected_type = validator_info - if not validator(value): + if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL): actual_type = type(value).__name__ + expected_type = cls._expected_type_label(param_type_enum) raise ValueError(f"Expected {expected_type}, got {actual_type}") return value @classmethod - def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None: + def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None: + if isinstance(param_type, SegmentType): + return param_type + try: + return SegmentType(param_type) + except Exception: + logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + return None + + @staticmethod + def _expected_type_label(param_type: SegmentType) -> str: + match param_type: + case SegmentType.ARRAY_STRING: + return "array of strings" + case SegmentType.ARRAY_NUMBER: + return "array of numbers" + case SegmentType.ARRAY_BOOLEAN: + return "array of booleans" + case SegmentType.ARRAY_OBJECT: + return "array of objects" + case _: + return param_type.value + + @classmethod + def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None: """Validate required headers are present. Args: @@ -639,14 +654,14 @@ class WebhookService: headers_lower = {k.lower(): v for k, v in headers.items()} headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()} for header_config in header_configs: - if header_config.get("required", False): - header_name = header_config.get("name", "") + if header_config.required: + header_name = header_config.name sanitized_name = cls._sanitize_key(header_name).lower() if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized: raise ValueError(f"Required header missing: {header_name}") @classmethod - def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Validate HTTP method and content-type. 
Args: @@ -657,13 +672,13 @@ class WebhookService: dict[str, Any]: Validation result with 'valid' key and optional 'error' key """ # Validate HTTP method - configured_method = node_data.get("method", "get").upper() + configured_method = node_data.method.value.upper() request_method = webhook_data["method"].upper() if configured_method != request_method: return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}") # Validate Content-type - configured_content_type = node_data.get("content_type", "application/json").lower() + configured_content_type = node_data.content_type.value.lower() request_content_type = cls._extract_content_type(webhook_data["headers"]) if configured_content_type != request_content_type: @@ -788,7 +803,7 @@ class WebhookService: raise @classmethod - def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]: + def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]: """Generate HTTP response based on node configuration. Args: @@ -797,11 +812,11 @@ class WebhookService: Returns: tuple[dict[str, Any], int]: Response data and HTTP status code """ - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) # Get configured status code and response body - status_code = node_data.get("status_code", 200) - response_body = node_data.get("response_body", "") + status_code = node_data.status_code + response_body = node_data.response_body # Parse response body as JSON if it's valid JSON, otherwise return as text try: diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 560aec2330..e028e3e5e3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService @@ -84,7 +85,7 @@ class WebConversationService: pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 0153046acc..3acbc93678 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -24,7 +24,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, IconType from models.workflow import Workflow, WorkflowType @@ -72,7 +72,7 @@ class WorkflowConverter: new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW - new_app.icon_type = icon_type or app_model.icon_type + new_app.icon_type = IconType(icon_type) if 
icon_type else app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6d462b60b9..5b24c356c2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,8 +14,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.errors import WorkflowNodeRunFailedError @@ -33,7 +35,6 @@ from dify_graph.nodes.human_input.entities import ( ) from dify_graph.nodes.human_input.enums import HumanInputFormKind from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from dify_graph.nodes.start.entities import StartNodeData from dify_graph.repositories.human_input_form_repository import FormCreateParams from dify_graph.runtime import GraphRuntimeState, VariablePool @@ -618,7 +619,7 @@ class WorkflowService: """ # return default block config default_block_configs: list[Mapping[str, object]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None if node_type is NodeType.HTTP_REQUEST: @@ -649,12 +650,13 @@ class WorkflowService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_workflow_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return {} - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] resolved_filters = dict(filters) if filters else {} if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: resolved_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( @@ -693,7 +695,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - node_data = node_config.get("data", {}) + node_data = node_config["data"] if node_type.is_start_node: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) @@ -703,7 +705,7 @@ class WorkflowService: workflow=draft_workflow, ) if node_type is NodeType.START: - start_data = StartNodeData.model_validate(node_data) + start_data = StartNodeData.model_validate(node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs ) @@ -941,7 
+943,7 @@ class WorkflowService: if node_type is not NodeType.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) + node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, @@ -951,7 +953,7 @@ class WorkflowService: delivery_method = apply_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id or "", + user_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1059,7 +1061,7 @@ class WorkflowService: *, workflow: Workflow, account: Account, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, variable_pool: VariablePool, ) -> HumanInputNode: graph_init_params = GraphInitParams( @@ -1079,7 +1081,7 @@ class WorkflowService: start_at=time.perf_counter(), ) node = HumanInputNode( - id=node_config.get("id", str(uuid.uuid4())), + id=node_config["id"], config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -1092,7 +1094,7 @@ class WorkflowService: *, app_model: App, workflow: Workflow, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, manual_inputs: Mapping[str, Any], ) -> VariablePool: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d06b8c980b..e7f4e37c75 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -164,7 +164,7 @@ def _record_trigger_failure_log( elapsed_time=0.0, total_tokens=0, total_steps=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, created_at=now, finished_at=now, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, ) session.add(workflow_app_log) @@ -212,7 +212,7 @@ def _record_trigger_failure_log( error=error_message, queue_name=queue_name, retry_count=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, triggered_at=now, finished_at=now, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index db8721e90b..f41118e592 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -94,13 +94,15 @@ def _create_workflow_run_from_execution( workflow_run.tenant_id = tenant_id workflow_run.app_id = app_id workflow_run.workflow_id = execution.workflow_id - workflow_run.type = execution.workflow_type.value - workflow_run.triggered_from = triggered_from.value + from models.workflow import WorkflowType as ModelWorkflowType + + workflow_run.type = ModelWorkflowType(execution.workflow_type.value) + workflow_run.triggered_from = triggered_from workflow_run.version = execution.workflow_version json_converter = WorkflowRuntimeTypeConverter() workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if 
execution.outputs else "{}" ) @@ -108,7 +110,7 @@ def _create_workflow_run_from_execution( workflow_run.elapsed_time = execution.elapsed_time workflow_run.total_tokens = execution.total_tokens workflow_run.total_steps = execution.total_steps - workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by_role = creator_user_role workflow_run.created_by = creator_user_id workflow_run.created_at = execution.started_at workflow_run.finished_at = execution.finished_at @@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo Update a WorkflowRun database model from a WorkflowExecution domain entity. """ json_converter = WorkflowRuntimeTypeConverter() - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 3f607dc55e..eaafbf99e3 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -98,7 +98,7 @@ def _create_node_execution_from_domain( node_execution.tenant_id = tenant_id node_execution.app_id = app_id node_execution.workflow_id = execution.workflow_id - node_execution.triggered_from = triggered_from.value + node_execution.triggered_from = triggered_from node_execution.workflow_run_id = execution.workflow_execution_id node_execution.index = execution.index node_execution.predecessor_node_id = execution.predecessor_node_id @@ -128,7 +128,7 @@ def _create_node_execution_from_domain( node_execution.status = execution.status.value node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time - node_execution.created_by_role = creator_user_role.value + node_execution.created_by_role = creator_user_role node_execution.created_by = creator_user_id node_execution.created_at = execution.created_at node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 39effbab58..37f8830482 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -60,7 +60,6 @@ VECTOR_STORE=weaviate # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih -WEAVIATE_GRPC_ENABLED=false WEAVIATE_BATCH_SIZE=100 WEAVIATE_TOKENIZATION=word diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index 498ac56d5d..afb6938baa 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -165,7 +165,7 @@ class TestChatMessageApiPermissions: agent_thoughts=[], message_files=[], message_metadata_dict={}, - status="success", + status="normal", error="", parent_message_id=None, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f691113511..347fa9c9ed 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -189,6 +189,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], 
indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from dify_graph.enums import NodeType from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, @@ -209,6 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): # Create node data with custom auth and empty api_key node_data = HttpRequestNodeData( + type=NodeType.HTTP_REQUEST, title="http", desc="", url="http://example.com", diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 23cb56d2a5..8a4fb8eda4 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,6 +1,6 @@ import time import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager @@ -87,17 +87,20 @@ def test_tool_variable_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None def test_tool_mixed_invoke(): @@ -121,12 +124,15 @@ def test_tool_mixed_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index cdf390b327..a60159c66a 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ 
b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -18,7 +18,7 @@ from faker import Faker from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @dataclass @@ -47,7 +47,7 @@ class TestTenantIsolatedTaskQueueIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -55,7 +55,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create second tenant tenant2 = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant2) db_session_with_containers.commit() @@ -410,7 +410,7 @@ class TestTenantIsolatedTaskQueueCompatibility: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -418,7 +418,7 @@ class TestTenantIsolatedTaskQueueCompatibility: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/models/test_app_model_config.py b/api/tests/test_containers_integration_tests/models/test_app_model_config.py new file mode 100644 index 0000000000..e8b36097e1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_app_model_config.py @@ -0,0 +1,32 @@ +""" +Integration tests for AppModelConfig using testcontainers. + +These tests validate database-backed model behavior without mocking SQLAlchemy queries. 
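For context, a db_session_with_containers-style fixture typically provisions a throwaway Postgres through testcontainers; a minimal sketch under that assumption (the project's actual conftest fixture differs and also prepares the schema):

import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer

@pytest.fixture
def db_session_with_containers():
    # One disposable Postgres per test; schema creation is omitted here.
    with PostgresContainer("postgres:16") as pg:
        engine = sa.create_engine(pg.get_connection_url())
        with Session(engine) as session:
            yield session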
+""" + +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from models.model import AppModelConfig + + +class TestAppModelConfig: + """Integration tests for AppModelConfig.""" + + def test_annotation_reply_dict_disabled_without_setting(self, db_session_with_containers: Session) -> None: + """Return disabled annotation reply dict when no AppAnnotationSetting exists.""" + # Arrange + config = AppModelConfig(app_id=str(uuid4())) + db_session_with_containers.add(config) + db_session_with_containers.commit() + + # Act + result = config.annotation_reply_dict + + # Assert + assert result == {"enabled": False} + + # Cleanup + db_session_with_containers.delete(config) + db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 9354a3ac35..cc9596d15f 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "suspended" + tenant.status = "archive" db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5155d50b0e..5b1a4790f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -2,6 +2,7 @@ import uuid from unittest.mock import ANY, MagicMock, patch import pytest +import sqlalchemy as sa from faker import Faker from sqlalchemy.orm import Session @@ -492,20 +493,20 @@ class TestAppGenerateService: ) # Manually set invalid mode after creation + # With EnumText, invalid values are rejected at the DB level during flush, + # raising StatementError wrapping ValueError app.mode = "invalid_mode" # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} - # Execute the method under test and expect ValueError - with pytest.raises(ValueError) as exc_info: + # Execute the method under test and expect either ValueError (direct) or + # StatementError (from EnumText validation during autoflush) + with pytest.raises((ValueError, sa.exc.StatementError)): AppGenerateService.generate( app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True ) - # Verify error message - assert "Invalid app mode" in str(exc_info.value) - def test_generate_with_workflow_id_format_error( self, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 6712fe8454..50f5b7a8c0 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -263,6 +263,27 @@ class TestFileService: user=account, ) + def test_upload_file_allows_regular_punctuation_in_filename( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): + """ + Test file upload allows punctuation that is safe when stored as metadata. 
+ """ + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = 'candidate?resume for "dify"|v2:.txt' + content = b"test content" + mimetype = "text/plain" + + upload_file = FileService(engine).upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file.name == filename + def test_upload_file_filename_too_long( self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index cc403ef5a2..dd743d46c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -163,7 +163,7 @@ class TestSavedMessageService: answer_unit_price=0.002, total_price=0.003, currency="USD", - status="success", + status="normal", ) db_session_with_containers.add(message) diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index f91e6efb10..970da98c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -173,7 +173,7 @@ class TestWebhookService: assert workflow.app_id == test_data["app"].id assert node_config is not None assert node_config["id"] == "webhook_node" - assert node_config["data"]["title"] == "Test Webhook" + assert node_config["data"].title == "Test Webhook" def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): """Test webhook trigger not found scenario.""" diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index bfb23bac68..d8b43efeba 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -1090,20 +1090,19 @@ class TestWorkflowService: This test ensures that the service correctly handles feature validation for unsupported app modes, preventing invalid operations. + With EnumText, invalid values are rejected at the DB level during flush, + raising StatementError wrapping ValueError. 
""" # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - db_session_with_containers.commit() + # Act & Assert - EnumText validation rejects invalid values at DB flush + import sqlalchemy as sa - workflow_service = WorkflowService() - features = {"test": "value"} - - # Act & Assert - with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): - workflow_service.validate_features_structure(app_model=app, features=features) + with pytest.raises((ValueError, sa.exc.StatementError)): + db_session_with_containers.commit() def test_update_workflow_success(self, db_session_with_containers: Session): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 8eb881258a..41d9fc8a29 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -110,7 +110,7 @@ class TestCleanDatasetTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index bc0ed3bd2b..69ed5b632d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,7 +48,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 8f47b48ae2..6f7d2c28b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -65,7 +65,7 @@ class TestDisableSegmentsFromIndexTask: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 3cdec70df7..c0ddc27286 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index cf52980e57..d6933e2180 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ 
b/api/tests/unit_tests/configs/test_dify_config.py @@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing # load dotenv file with pydantic-settings - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # constant values assert config.COMMIT_SHA == "" @@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DB_PORT", "5432") monkeypatch.setenv("DB_DATABASE", "dify") - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # Verify default timeout values assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10 @@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*") monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/") - flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore + # Disable `.env` loading to ensure test stability across environments + flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 55fb038156..726c0a5cf3 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -51,7 +51,7 @@ def bypass_decorators(mocker): ) mocker.patch( "controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check", - return_value=lambda *_: (lambda f: f), + return_value=lambda *_: lambda f: f, ) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index c3a6522e6d..6b5c304884 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -48,7 +48,7 @@ def make_message(): msg.query = "hello" msg.re_sign_file_url_answer = "" msg.user_feedback = MagicMock(rating=None) - msg.status = "success" + msg.status = "normal" msg.error = None return msg diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index f6db55db5b..eb19243225 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -200,10 +200,13 @@ class TestPluginUploadFromPkgApi: app.test_request_context("/", data=data, content_type="multipart/form-data"), patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): with pytest.raises(ValueError): method(api) + upload_pkg_mock.assert_not_called() + class TestPluginInstallFromPkgApi: def test_install_from_pkg(self, app): @@ -444,10 +447,13 @@ class TestPluginUploadFromBundleApi: ), patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), 
patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): with pytest.raises(ValueError): method(api) + upload_bundle_mock.assert_not_called() + class TestPluginInstallFromGithubApi: def test_success(self, app): diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/controllers/inner_api/__init__.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__init__.py rename to api/tests/unit_tests/controllers/inner_api/__init__.py diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py b/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py new file mode 100644 index 0000000000..844f04fe72 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py @@ -0,0 +1,313 @@ +""" +Unit tests for inner_api plugin endpoints + +Tests endpoint structure (method existence) for all plugin APIs, plus +handler-level logic tests for representative non-streaming endpoints. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.inner_api.plugin.plugin import ( + PluginFetchAppInfoApi, + PluginInvokeAppApi, + PluginInvokeEncryptApi, + PluginInvokeLLMApi, + PluginInvokeLLMWithStructuredOutputApi, + PluginInvokeModerationApi, + PluginInvokeParameterExtractorNodeApi, + PluginInvokeQuestionClassifierNodeApi, + PluginInvokeRerankApi, + PluginInvokeSpeech2TextApi, + PluginInvokeSummaryApi, + PluginInvokeTextEmbeddingApi, + PluginInvokeToolApi, + PluginInvokeTTSApi, + PluginUploadFileRequestApi, +) + + +def _extract_raw_post(cls): + """Extract the raw post() method from a plugin endpoint class. + + Plugin endpoint methods are wrapped by several decorators (get_user_tenant, + setup_required, plugin_inner_api_only, plugin_data). These decorators + use @wraps where possible. This helper ensures we retrieve the original + post(self, user_model, tenant_model, payload) function by unwrapping + and, if necessary, walking the closure of the innermost wrapper. + """ + bottom = inspect.unwrap(cls.post) + + # If unwrap() didn't get us to the raw function (e.g. if a decorator + # missed @wraps), try to extract it from the closure if it looks like + # a plugin_data or similar wrapper that closes over 'view_func'. 
+ if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars: + try: + idx = bottom.__code__.co_freevars.index("view_func") + return bottom.__closure__[idx].cell_contents + except (AttributeError, TypeError, IndexError): + pass + + return bottom + + +class TestPluginInvokeLLMApi: + """Test PluginInvokeLLMApi endpoint structure""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMApi() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeLLMWithStructuredOutputApi: + """Test PluginInvokeLLMWithStructuredOutputApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMWithStructuredOutputApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTextEmbeddingApi: + """Test PluginInvokeTextEmbeddingApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTextEmbeddingApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeRerankApi: + """Test PluginInvokeRerankApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeRerankApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTTSApi: + """Test PluginInvokeTTSApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTTSApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeSpeech2TextApi: + """Test PluginInvokeSpeech2TextApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSpeech2TextApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeModerationApi: + """Test PluginInvokeModerationApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeModerationApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeToolApi: + """Test PluginInvokeToolApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeToolApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeParameterExtractorNodeApi: + """Test PluginInvokeParameterExtractorNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeParameterExtractorNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeQuestionClassifierNodeApi: + """Test PluginInvokeQuestionClassifierNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeQuestionClassifierNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeAppApi: + """Test PluginInvokeAppApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeAppApi() + + def test_has_post_method(self, 
api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeEncryptApi: + """Test PluginInvokeEncryptApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeEncryptApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask): + """Test that post() delegates to PluginEncrypter and returns model_dump output""" + # Arrange + mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"} + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act — extract raw post() bypassing all decorators including plugin_data + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload) + assert result["data"] == {"encrypted": "data"} + assert result.get("error") == "" + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask): + """Test that post() catches exceptions and returns error response""" + # Arrange + mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed") + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + assert "encrypt failed" in result["error"] + + +class TestPluginInvokeSummaryApi: + """Test PluginInvokeSummaryApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSummaryApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginUploadFileRequestApi: + """Test PluginUploadFileRequestApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginUploadFileRequestApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin") + def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask): + """Test that post() generates a signed URL and returns it""" + # Arrange + mock_get_url.return_value = "https://storage.example.com/signed-upload-url" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_user.id = "user-id" + mock_payload = MagicMock() + mock_payload.filename = "test.pdf" + mock_payload.mimetype = "application/pdf" + + # Act + raw_post = _extract_raw_post(PluginUploadFileRequestApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_get_url.assert_called_once_with( + filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id" + ) + assert result["data"]["url"] == "https://storage.example.com/signed-upload-url" + + +class TestPluginFetchAppInfoApi: + """Test PluginFetchAppInfoApi endpoint structure and handler logic""" + + @pytest.fixture + def 
api_instance(self): + return PluginFetchAppInfoApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation") + def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask): + """Test that post() fetches app info and returns it""" + # Arrange + mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"} + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_payload = MagicMock() + mock_payload.app_id = "app-123" + + # Act + raw_post = _extract_raw_post(PluginFetchAppInfoApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id") + assert result["data"] == {"app_name": "My App", "mode": "chat"} diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py new file mode 100644 index 0000000000..6de07a23e5 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -0,0 +1,305 @@ +""" +Unit tests for inner_api plugin decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.plugin.wraps import ( + TenantUserPayload, + get_user, + get_user_tenant, + plugin_data, +) + + +class TestTenantUserPayload: + """Test TenantUserPayload Pydantic model""" + + def test_valid_payload(self): + """Test valid payload passes validation""" + data = {"tenant_id": "tenant123", "user_id": "user456"} + payload = TenantUserPayload.model_validate(data) + assert payload.tenant_id == "tenant123" + assert payload.user_id == "user456" + + def test_missing_tenant_id(self): + """Test missing tenant_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"user_id": "user456"}) + + def test_missing_user_id(self): + """Test missing user_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"tenant_id": "tenant123"}) + + +class TestGetUser: + """Test get_user function""" + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test returning existing user when found by ID""" + # Arrange + mock_user = MagicMock() + mock_user.id = "user123" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_user + mock_session.query.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_anonymous_user_by_session_id( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test returning existing anonymous user by session_id""" + # Arrange + mock_user = MagicMock() + mock_user.session_id = 
"anonymous_session" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "anonymous_session") + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test creating new user when not found in database""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = None + mock_new_user = MagicMock() + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_new_user + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_use_default_session_id_when_user_id_none( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test using default session ID when user_id is None""" + # Arrange + mock_user = MagicMock() + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.return_value.where.return_value.first.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", None) + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_raise_error_on_database_exception( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test raising ValueError when database operation fails""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.query.side_effect = Exception("Database error") + + # Act & Assert + with app.app_context(): + with pytest.raises(ValueError, match="user not found"): + get_user("tenant123", "user123") + + +class TestGetUserTenant: + """Test get_user_tenant decorator""" + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that decorator injects tenant_model and user_model into kwargs""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + mock_user.id = "user456" + + # Act + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get_user.return_value 
= mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + + def test_should_raise_error_when_tenant_id_missing(self, app: Flask): + """Test that Pydantic ValidationError is raised when tenant_id is missing from payload""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert - Pydantic validates payload before manual check + with app.test_request_context(json={"user_id": "user456"}): + with pytest.raises(ValidationError): + protected_view() + + def test_should_raise_error_when_tenant_not_found(self, app: Flask): + """Test that ValueError is raised when tenant is not found""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert + with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}): + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="tenant not found"): + protected_view() + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that default session ID is used when user_id is empty string""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + + # Act - use empty string for user_id to trigger default logic + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_get_user.return_value = mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + from models.model import DefaultEndUserSessionID + + mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID) + + +class PluginTestPayload: + """Simple test payload class""" + + def __init__(self, data: dict): + self.value = data.get("value") + + @classmethod + def model_validate(cls, data: dict): + return cls(data) + + +class TestPluginData: + """Test plugin_data decorator""" + + def test_should_inject_valid_payload(self, app: Flask): + """Test that valid payload is injected into kwargs""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "test_data"}): + result = protected_view() + + # Assert + assert result.value == "test_data" + + def test_should_raise_error_on_invalid_json(self, app: Flask): + """Test that ValueError is raised when JSON parsing fails""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert - Malformed JSON triggers ValueError + with app.test_request_context(data="not valid json", content_type="application/json"): + with pytest.raises(ValueError): + protected_view() + + def 
test_should_raise_error_on_invalid_payload(self, app: Flask): + """Test that ValueError is raised when payload validation fails""" + + # Arrange + class InvalidPayload: + @classmethod + def model_validate(cls, data: dict): + raise Exception("Validation failed") + + @plugin_data(payload_type=InvalidPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert + with app.test_request_context(json={"data": "test"}): + with pytest.raises(ValueError, match="invalid payload"): + protected_view() + + def test_should_work_as_parameterized_decorator(self, app: Flask): + """Test that decorator works when used with parentheses""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "parameterized"}): + result = protected_view() + + # Assert + assert result.value == "parameterized" diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py new file mode 100644 index 0000000000..883ccdea2c --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -0,0 +1,309 @@ +""" +Unit tests for inner_api auth decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import HTTPException + +from configs import dify_config +from controllers.inner_api.wraps import ( + billing_inner_api_only, + enterprise_inner_api_only, + enterprise_inner_api_user_auth, + plugin_inner_api_only, +) + + +class TestBillingInnerApiOnly: + """Test billing_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, 
"INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiOnly: + """Test enterprise_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiUserAuth: + """Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication""" + + def test_should_pass_through_when_inner_api_disabled(self, app: Flask): + """Test that request passes through when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_header_missing(self, app: Flask): + """Test that request passes through when Authorization header is missing""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_format_invalid(self, app: Flask): + """Test that request passes through when Authorization format is invalid (no colon)""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # 
Act + with app.test_request_context(headers={"Authorization": "invalid_format"}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask): + """Test that request passes through when HMAC signature is invalid""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act - use wrong signature + with app.test_request_context( + headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"} + ): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_inject_user_when_hmac_signature_valid(self, app: Flask): + """Test that user is injected when HMAC signature is valid""" + # Arrange + from base64 import b64encode + from hashlib import sha1 + from hmac import new as hmac_new + + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user") + + # Calculate valid HMAC signature + user_id = "user123" + inner_api_key = "valid_key" + data_to_sign = f"DIFY {user_id}" + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + valid_signature = b64encode(signature.digest()).decode("utf-8") + + # Create mock user + mock_user = MagicMock() + mock_user.id = user_id + + # Act + with app.test_request_context( + headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} + ): + with patch.object(dify_config, "INNER_API", True): + with patch("controllers.inner_api.wraps.db.session.query") as mock_query: + mock_query.return_value.where.return_value.first.return_value = mock_user + result = protected_view() + + # Assert + assert result == mock_user + + +class TestPluginInnerApiOnly: + """Test plugin_inner_api_only decorator""" + + def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask): + """Test that valid API key allows access when PLUGIN_DAEMON_KEY is set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask): + """Test that 404 is returned when PLUGIN_DAEMON_KEY is not set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_404_when_api_key_invalid(self, app: Flask): + """Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert 
exc_info.value.code == 404 diff --git a/api/tests/unit_tests/controllers/inner_api/test_mail.py b/api/tests/unit_tests/controllers/inner_api/test_mail.py new file mode 100644 index 0000000000..c2ca35693e --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_mail.py @@ -0,0 +1,206 @@ +""" +Unit tests for inner_api mail module +""" + +from unittest.mock import patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.mail import ( + BaseMail, + BillingMail, + EnterpriseMail, + InnerMailPayload, +) + + +class TestInnerMailPayload: + """Test InnerMailPayload Pydantic model""" + + def test_valid_payload_with_all_fields(self): + """Test valid payload with all fields passes validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + "substitutions": {"key": "value"}, + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions == {"key": "value"} + + def test_valid_payload_without_substitutions(self): + """Test valid payload without optional substitutions""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions is None + + def test_empty_to_list_fails_validation(self): + """Test that empty 'to' list fails validation due to min_length=1""" + data = { + "to": [], + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_multiple_recipients_allowed(self): + """Test that multiple recipients are allowed""" + data = { + "to": ["user1@example.com", "user2@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert len(payload.to) == 2 + assert "user1@example.com" in payload.to + assert "user2@example.com" in payload.to + + def test_missing_to_field_fails_validation(self): + """Test that missing 'to' field fails validation""" + data = { + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_subject_fails_validation(self): + """Test that missing 'subject' field fails validation""" + data = { + "to": ["test@example.com"], + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_body_fails_validation(self): + """Test that missing 'body' field fails validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + +class TestBaseMail: + """Test BaseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BaseMail API instance""" + return BaseMail() + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_sends_email_task(self, mock_task, api_instance, app: Flask): + """Test that POST sends inner email task""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context( + json={ + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + ): + with 
patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Test Subject", + body="Test Body", + substitutions=None, + ) + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_with_substitutions(self, mock_task, api_instance, app: Flask): + """Test that POST sends email with substitutions""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context(): + with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Hello {{name}}", + "body": "Welcome {{name}}!", + "substitutions": {"name": "John"}, + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Hello {{name}}", + body="Welcome {{name}}!", + substitutions={"name": "John"}, + ) + + +class TestEnterpriseMail: + """Test EnterpriseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create EnterpriseMail API instance""" + return EnterpriseMail() + + def test_has_enterprise_inner_api_only_decorator(self, api_instance): + """Test that EnterpriseMail has enterprise_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import enterprise_inner_api_only + + assert enterprise_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that EnterpriseMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names + + +class TestBillingMail: + """Test BillingMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BillingMail API instance""" + return BillingMail() + + def test_has_billing_inner_api_only_decorator(self, api_instance): + """Test that BillingMail has billing_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import billing_inner_api_only + + assert billing_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that BillingMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py b/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py new file mode 100644 index 0000000000..4fbf0f7125 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -0,0 +1,184 @@ +""" +Unit tests for inner_api workspace module + +Tests Pydantic model validation and endpoint handler logic. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them and focus on business logic. 
+""" + +import inspect +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.workspace.workspace import ( + EnterpriseWorkspace, + EnterpriseWorkspaceNoOwnerEmail, + WorkspaceCreatePayload, + WorkspaceOwnerlessPayload, +) + + +class TestWorkspaceCreatePayload: + """Test WorkspaceCreatePayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with all fields passes validation""" + data = { + "name": "My Workspace", + "owner_email": "owner@example.com", + } + payload = WorkspaceCreatePayload.model_validate(data) + assert payload.name == "My Workspace" + assert payload.owner_email == "owner@example.com" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {"owner_email": "owner@example.com"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "name" in str(exc_info.value) + + def test_missing_owner_email_fails_validation(self): + """Test that missing owner_email fails validation""" + data = {"name": "My Workspace"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "owner_email" in str(exc_info.value) + + +class TestWorkspaceOwnerlessPayload: + """Test WorkspaceOwnerlessPayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with name passes validation""" + data = {"name": "My Workspace"} + payload = WorkspaceOwnerlessPayload.model_validate(data) + assert payload.name == "My Workspace" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {} + with pytest.raises(ValidationError) as exc_info: + WorkspaceOwnerlessPayload.model_validate(data) + assert "name" in str(exc_info.value) + + +class TestEnterpriseWorkspace: + """Test EnterpriseWorkspace API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. 
+ """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspace() + + def test_has_post_method(self, api_instance): + """Test that EnterpriseWorkspace has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace and assigns the owner account""" + # Arrange + mock_account = MagicMock() + mock_account.email = "owner@example.com" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["name"] == "My Workspace" + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_event.send.assert_called_once_with(mock_tenant) + + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask): + """Test that post() returns 404 when the owner account does not exist""" + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result == ({"message": "owner account not found."}, 404) + + +class TestEnterpriseWorkspaceNoOwnerEmail: + """Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. 
+ """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspaceNoOwnerEmail() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace without an owner and returns expected fields""" + # Arrange + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.encrypt_public_key = "pub-key" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.custom_config = None + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["encrypt_public_key"] == "pub-key" + assert result["tenant"]["custom_config"] == {} + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_event.send.assert_called_once_with(mock_tenant) diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2bb425cdba 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None: {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, message_file_obj, ], - status="success", + status="normal", error=None, message_metadata_dict={"meta": "value"}, extra_contents=[ diff --git a/api/tests/unit_tests/core/agent/conftest.py b/api/tests/unit_tests/core/agent/conftest.py new file mode 100644 index 0000000000..a2aa501720 --- /dev/null +++ b/api/tests/unit_tests/core/agent/conftest.py @@ -0,0 +1,80 @@ +import pytest + + +class DummyTool: + def __init__(self, name): + self.name = name + + +class DummyPromptEntity: + def __init__(self, first_prompt): + self.first_prompt = first_prompt + + +class DummyAgentConfig: + def __init__(self, prompt_entity=None): + self.prompt = prompt_entity + + +class DummyAppConfig: + def __init__(self, agent=None): + self.agent = agent + + +class DummyScratchpadUnit: + def __init__( + self, + final=False, + thought=None, + action_str=None, + observation=None, + agent_response=None, + ): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def dummy_tool_factory(): + def _factory(name): + return DummyTool(name) + + return _factory + + +@pytest.fixture +def dummy_prompt_entity_factory(): + def _factory(first_prompt): + return DummyPromptEntity(first_prompt) + + return _factory + + +@pytest.fixture 
+def dummy_agent_config_factory(): + def _factory(prompt_entity=None): + return DummyAgentConfig(prompt_entity) + + return _factory + + +@pytest.fixture +def dummy_app_config_factory(): + def _factory(agent=None): + return DummyAppConfig(agent) + + return _factory + + +@pytest.fixture +def dummy_scratchpad_unit_factory(): + def _factory(**kwargs): + return DummyScratchpadUnit(**kwargs) + + return _factory diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index ba8c903f65..9073ae1044 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -1,70 +1,255 @@ +"""Unit tests for CotAgentOutputParser. + +Verifies expected parsing behavior for streaming content and JSON payloads, +including edge cases such as empty/non-string content and malformed JSON. +Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real +model output structures. Implementation under test: +core.agent.output_parser.cot_output_parser.CotAgentOutputParser. +""" + import json -from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest -from core.agent.entities import AgentScratchpadUnit from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta -def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]: - for i in range(len(text)): - yield LLMResultChunk( - model="model", - prompt_messages=[], - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])), +@pytest.fixture +def mock_action_class(mocker): + mock_action = MagicMock() + mocker.patch( + "core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action", + mock_action, + ) + return mock_action + + +@pytest.fixture +def usage_dict(): + return {} + + +@pytest.fixture +def make_chunk(): + def _make_chunk(content=None, usage=None): + delta = SimpleNamespace( + message=SimpleNamespace(content=content), + usage=usage, ) + return SimpleNamespace(delta=delta) + + return _make_chunk -def test_cot_output_parser(): - test_cases = [ - { - "input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with json - { - "input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with JSON - { - "input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # list - { - "input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 
'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block - { - "input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block and json - {"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"}, - ] +# ============================================================ +# Test Suite +# ============================================================ - parser = CotAgentOutputParser() - usage_dict = {} - for test_case in test_cases: - # mock llm_response as a generator by text - llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"]) - results = parser.handle_react_stream_output(llm_response, usage_dict) - output = "" - for result in results: - if isinstance(result, str): - output += result - elif isinstance(result, AgentScratchpadUnit.Action): - if test_case["action"]: - assert result.to_dict() == test_case["action"] - output += json.dumps(result.to_dict()) - if test_case["output"]: - assert output == test_case["output"] + +class TestCotAgentOutputParser: + """Validate CotAgentOutputParser streaming + JSON parsing behavior. + + Lifecycle: no explicit setup/teardown; relies on pytest fixtures for + lightweight chunk/action doubles. Invariants: non-string/empty content + yields no output, usage gets recorded when provided, and valid action JSON + results in Action instantiation. Usage: invoke via pytest (e.g., + `pytest -k TestCotAgentOutputParser`). + """ + + # -------------------------------------------------------- + # Basic streaming & usage + # -------------------------------------------------------- + + def test_stream_plain_text(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("hello world")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "".join(result) == "hello world" + + def test_stream_empty_string(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_stream_none_content(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk(None)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + @pytest.mark.parametrize("content", [123, 12.5, [], {}, object()]) + def test_non_string_content(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_usage_update(self, make_chunk, usage_dict) -> None: + usage_data = {"tokens": 99} + chunks = [make_chunk("abc", usage=usage_data)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert usage_dict["usage"] == usage_data + + # -------------------------------------------------------- + # JSON parsing (direct + streaming) + # -------------------------------------------------------- + + def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '{"action": "search", "input": "query"}' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + 
mock_action_class.assert_called_once_with(action_name="search", action_input="query") + + def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '[{"action": "lookup", "input": "abc"}]' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None: + content = '{"foo": "bar"}' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # Expect the serialized JSON to be yielded as a single element. + assert result == [json.dumps({"foo": "bar"})] + + def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None: + content = "{invalid json}" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert any("invalid json" in str(r) for r in result) + + def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None: + chunks = [ + make_chunk('{"action": '), + make_chunk('"multi", '), + make_chunk('"input": "step"}'), + ] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="multi", action_input="step") + + def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('{"foo": "bar"')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('{"foo": "bar"' in item for item in result) + + # -------------------------------------------------------- + # Code block JSON extraction + # -------------------------------------------------------- + + def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = """```json +{"action": "lookup", "input": "abc"} +```""" + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None: + # Multiple JSON objects inside single code fence (invalid combined JSON) + # Parser should safely ignore invalid combined block + content = """```json +{"action": "a1", "input": "x"} +{"action": "a2", "input": "y"} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # No valid parsed action expected due to invalid combined JSON + assert mock_action_class.call_count == 0 + assert isinstance(result, list) + + def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None: + content = """```json +{invalid} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result + + def test_unclosed_code_block(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('```json {"a":1}')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('```json {"a":1}' in item for item in result) + + # -------------------------------------------------------- + # Action / Thought prefix handling + # -------------------------------------------------------- + + 
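+    # These cases pin down the expected prefix handling: a leading
+    # "action:"/"thought:" marker is stripped (case-insensitively) and only the
+    # text that follows it is streamed.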
@pytest.mark.parametrize( + "content", + [ + " action: something", + " ACTION: something", + " thought: reasoning", + " THOUGHT: reasoning", + ], + ) + def test_prefix_handling(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + joined = "".join(str(item) for item in result) + expected_word = "something" if "action:" in content.lower() else "reasoning" + assert expected_word in joined + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() + + def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("xaction: test")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "x" in "".join(map(str, result)) + + # -------------------------------------------------------- + # Mixed streaming scenarios + # -------------------------------------------------------- + + def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None: + content = 'start {"action": "mix", "input": "1"} end' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # JSON action should be parsed + mock_action_class.assert_called_once() + # Ensure surrounding text is streamed (character-level) + joined = "".join(str(r) for r in result if not isinstance(r, MagicMock)) + assert "start" in joined + assert "end" in joined + + def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert mock_action_class.call_count == 2 + + def test_backtick_noise(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("text with ` random ` backticks")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "text with" in "".join(result) + + # -------------------------------------------------------- + # Boundary & edge inputs + # -------------------------------------------------------- + + @pytest.mark.parametrize( + "content", + [ + "```", + "{", + "}", + "```json", + "action:", + "thought:", + " ", + ], + ) + def test_edge_inputs(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + joined = "".join(result) + if content == " ": + assert result == [] or joined == content + if content in {"```", "{", "}", "```json"}: + assert content in joined + if content.lower() in {"action:", "thought:"}: + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() diff --git a/api/tests/unit_tests/core/agent/strategy/test_base.py b/api/tests/unit_tests/core/agent/strategy/test_base.py new file mode 100644 index 0000000000..83ff79e8a1 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_base.py @@ -0,0 +1,174 @@ +from collections.abc import Generator +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.base import BaseAgentStrategy + + +class DummyStrategy(BaseAgentStrategy): + """ + Concrete implementation for testing BaseAgentStrategy + """ + + def __init__(self, return_values=None, raise_exception=None): + 
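+        # return_values: items the stubbed _invoke generator yields;
+        # raise_exception: if set, raised inside _invoke to test propagation.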
self.return_values = return_values or [] + self.raise_exception = raise_exception + self.received_args = None + + def _invoke( + self, + params, + user_id, + conversation_id=None, + app_id=None, + message_id=None, + credentials=None, + ) -> Generator: + self.received_args = ( + params, + user_id, + conversation_id, + app_id, + message_id, + credentials, + ) + + if self.raise_exception: + raise self.raise_exception + + yield from self.return_values + + +class TestBaseAgentStrategyInstantiation: + def test_cannot_instantiate_abstract_class(self) -> None: + with pytest.raises(TypeError): + BaseAgentStrategy() + + +class TestBaseAgentStrategyInvoke: + @pytest.fixture + def mock_message(self): + return MagicMock(name="AgentInvokeMessage") + + @pytest.fixture + def mock_credentials(self): + return MagicMock(name="InvokeCredentials") + + @pytest.mark.parametrize( + ("params", "user_id", "conversation_id", "app_id", "message_id"), + [ + ({"key": "value"}, "user1", "conv1", "app1", "msg1"), + ({}, "user2", None, None, None), + ({"a": 1}, "", "", "", ""), + ({"nested": {"x": 1}}, "user3", None, "app3", None), + ], + ) + def test_invoke_success( + self, + mock_message, + mock_credentials, + params, + user_id, + conversation_id, + app_id, + message_id, + ) -> None: + # Arrange + strategy = DummyStrategy(return_values=[mock_message]) + + # Act + result = list( + strategy.invoke( + params=params, + user_id=user_id, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + credentials=mock_credentials, + ) + ) + + # Assert + assert result == [mock_message] + assert strategy.received_args == ( + params, + user_id, + conversation_id, + app_id, + message_id, + mock_credentials, + ) + + def test_invoke_multiple_yields(self, mock_message) -> None: + # Arrange + messages = [mock_message, MagicMock(), MagicMock()] + strategy = DummyStrategy(return_values=messages) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == messages + + def test_invoke_empty_generator(self) -> None: + # Arrange + strategy = DummyStrategy(return_values=[]) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == [] + + def test_invoke_propagates_exception(self) -> None: + # Arrange + strategy = DummyStrategy(raise_exception=ValueError("failure")) + + # Act & Assert + with pytest.raises(ValueError, match="failure"): + list(strategy.invoke(params={}, user_id="user")) + + @pytest.mark.parametrize( + "invalid_params", + [ + None, + "", + 123, + [], + ], + ) + def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None: + """ + Base class does not validate types — ensure pass-through behavior + """ + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params=invalid_params, user_id="user")) + + assert result == [] + + def test_invoke_none_user_id(self) -> None: + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params={}, user_id=None)) + + assert result == [] + + +class TestBaseAgentStrategyGetParameters: + def test_get_parameters_default_empty_list(self) -> None: + strategy = DummyStrategy() + result = strategy.get_parameters() + + assert isinstance(result, list) + assert result == [] + + def test_get_parameters_returns_new_list_each_time(self) -> None: + strategy = DummyStrategy() + + first = strategy.get_parameters() + second = strategy.get_parameters() + + assert first == second == [] + assert first is not second diff --git 
a/api/tests/unit_tests/core/agent/strategy/test_plugin.py b/api/tests/unit_tests/core/agent/strategy/test_plugin.py new file mode 100644 index 0000000000..e0894f1e90 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_plugin.py @@ -0,0 +1,272 @@ +# File: tests/unit_tests/core/agent/strategy/test_plugin.py + +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.plugin import PluginAgentStrategy + +# ============================================================ +# Fixtures +# ============================================================ + + +@pytest.fixture +def mock_parameter(): + def _factory(name="param", return_value="initialized"): + param = MagicMock() + param.name = name + param.init_frontend_parameter = MagicMock(return_value=return_value) + return param + + return _factory + + +@pytest.fixture +def mock_declaration(mock_parameter): + param1 = mock_parameter("param1", "init1") + param2 = mock_parameter("param2", "init2") + + identity = MagicMock() + identity.provider = "provider_x" + identity.name = "strategy_x" + + declaration = MagicMock() + declaration.parameters = [param1, param2] + declaration.identity = identity + + return declaration + + +@pytest.fixture +def strategy(mock_declaration): + return PluginAgentStrategy( + tenant_id="tenant_123", + declaration=mock_declaration, + meta_version="v1", + ) + + +# ============================================================ +# Initialization Tests +# ============================================================ + + +class TestPluginAgentStrategyInitialization: + def test_init_sets_attributes(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version="meta_v", + ) + + assert strategy.tenant_id == "tenant_test" + assert strategy.declaration == mock_declaration + assert strategy.meta_version == "meta_v" + + def test_init_meta_version_none(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version=None, + ) + + assert strategy.meta_version is None + + +# ============================================================ +# get_parameters Tests +# ============================================================ + + +class TestGetParameters: + def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None: + result = strategy.get_parameters() + assert result == mock_declaration.parameters + + +# ============================================================ +# initialize_parameters Tests +# ============================================================ + + +class TestInitializeParameters: + def test_initialize_parameters_success(self, strategy, mock_declaration) -> None: + params = {"param1": "value1"} + + result = strategy.initialize_parameters(params.copy()) + + assert result["param1"] == "init1" + assert result["param2"] == "init2" + + mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1") + mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None) + + @pytest.mark.parametrize( + "input_params", + [ + {}, + {"param1": None}, + {"param1": ""}, + {"param1": 0}, + {"param1": []}, + {"param1": {}, "param2": "value"}, + ], + ) + def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None: + result = strategy.initialize_parameters(input_params.copy()) + + for param in strategy.declaration.parameters: + assert param.name in result + + def 
test_initialize_parameters_invalid_input_type(self, strategy) -> None: + with pytest.raises(AttributeError): + strategy.initialize_parameters(None) + + +# ============================================================ +# _invoke Tests +# ============================================================ + + +class TestInvoke: + def test_invoke_success_all_arguments(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mock_convert = mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={"converted": True}, + ) + + result = list( + strategy._invoke( + params={"param1": "value"}, + user_id="user_1", + conversation_id="conv_1", + app_id="app_1", + message_id="msg_1", + credentials=None, + ) + ) + + assert result == ["msg1", "msg2"] + mock_convert.assert_called_once() + mock_manager.invoke.assert_called_once() + + call_kwargs = mock_manager.invoke.call_args.kwargs + assert call_kwargs["tenant_id"] == "tenant_123" + assert call_kwargs["user_id"] == "user_1" + assert call_kwargs["agent_provider"] == "provider_x" + assert call_kwargs["agent_strategy"] == "strategy_x" + assert call_kwargs["agent_params"] == {"converted": True} + assert call_kwargs["conversation_id"] == "conv_1" + assert call_kwargs["app_id"] == "app_1" + assert call_kwargs["message_id"] == "msg_1" + assert call_kwargs["context"] is not None + + def test_invoke_with_credentials(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + # Patch PluginInvokeContext to bypass pydantic validation + mock_context = MagicMock() + mocker.patch( + "core.agent.strategy.plugin.PluginInvokeContext", + return_value=mock_context, + ) + + credentials = MagicMock() + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + credentials=credentials, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + @pytest.mark.parametrize( + ("conversation_id", "app_id", "message_id"), + [ + (None, None, None), + ("conv", None, None), + (None, "app", None), + (None, None, "msg"), + ], + ) + def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + def test_invoke_convert_raises_exception(self, strategy, mocker) -> None: + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=MagicMock(), + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + side_effect=ValueError("conversion failed"), + ) + + with pytest.raises(ValueError): + list(strategy._invoke(params={}, user_id="user_1")) + + def 
test_invoke_manager_raises_exception(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke.side_effect = RuntimeError("invoke failed") + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + with pytest.raises(RuntimeError): + list(strategy._invoke(params={}, user_id="user_1")) diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py new file mode 100644 index 0000000000..683cc0e36f --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -0,0 +1,802 @@ +import json +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +import core.agent.base_agent_runner as module +from core.agent.base_agent_runner import BaseAgentRunner + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture +def mock_db_session(mocker): + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + return session + + +@pytest.fixture +def runner(mocker, mock_db_session): + r = BaseAgentRunner.__new__(BaseAgentRunner) + r.tenant_id = "tenant" + r.user_id = "user" + r.agent_thought_count = 0 + r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1") + r.app_config = mocker.MagicMock() + r.app_config.app_id = "app1" + r.app_config.agent = None + r.dataset_tools = [] + r.application_generate_entity = mocker.MagicMock(invoke_from="test") + r._current_thoughts = [] + return r + + +# ========================================================== +# _repack_app_generate_entity +# ========================================================== + + +class TestRepack: + def test_sets_empty_if_none(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = None + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "" + + def test_keeps_existing(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = "abc" + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "abc" + + +# ========================================================== +# update_prompt_message_tool +# ========================================================== + + +class TestUpdatePromptTool: + def build_param(self, mocker, **kwargs): + p = mocker.MagicMock() + p.form = kwargs.get("form") + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + p.type = mock_type + + p.name = kwargs.get("name", "p1") + p.llm_description = "desc" + p.input_schema = kwargs.get("input_schema") + p.options = kwargs.get("options") + p.required = kwargs.get("required", False) + return p + + def test_skip_non_llm(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form="NOT_LLM") + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_enum_and_required(self, runner, mocker): + option = mocker.MagicMock(value="opt1") + param = 
self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + options=[option], + required=True, + ) + + tool = mocker.MagicMock() + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert "p1" in result.parameters["required"] + + def test_skip_file_type_param(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM) + param.type = module.ToolParameter.ToolParameterType.FILE + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_duplicate_required_not_duplicated(self, runner, mocker): + tool = mocker.MagicMock() + + param = self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + required=True, + ) + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": ["p1"]} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["required"].count("p1") == 1 + + +# ========================================================== +# create_agent_thought +# ========================================================== + + +class TestCreateAgentThought: + def test_with_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=10) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"]) + assert result == "10" + assert runner.agent_thought_count == 1 + + def test_without_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=11) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", []) + assert result == "11" + + +# ========================================================== +# save_agent_thought +# ========================================================== + + +class TestSaveAgentThought: + def setup_agent(self, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;tool2" + agent.tool_labels = {} + agent.thought = "" + return agent + + def test_not_found(self, runner, mock_db_session): + mock_db_session.scalar.return_value = None + with pytest.raises(ValueError): + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + def test_full_update(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + mock_label = mocker.MagicMock() + mock_label.to_dict.return_value = {"en_US": "label"} + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label) + + usage = mocker.MagicMock( + prompt_tokens=1, + prompt_price_unit=Decimal("0.1"), + prompt_unit_price=Decimal("0.1"), + completion_tokens=2, + completion_price_unit=Decimal("0.2"), + completion_unit_price=Decimal("0.2"), + total_tokens=3, + total_price=Decimal("0.3"), + ) + + runner.save_agent_thought( + "id", + "tool1;tool2", + {"a": 1}, + "thought", + {"b": 2}, + {"meta": 1}, + "answer", + ["f1"], + usage, + ) + + assert agent.answer == "answer" + assert agent.tokens == 3 + 
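+        # Labels are resolved per tool name via the patched
+        # ToolManager.get_tool_label and serialized into tool_labels_str as JSON.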
assert "tool1" in json.loads(agent.tool_labels_str) + + def test_label_fallback_when_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + agent.tool = "unknown_tool" + mock_db_session.scalar.return_value = agent + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert "unknown_tool" in labels + + def test_json_failure_paths(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + bad_obj = MagicMock() + bad_obj.__str__.return_value = "bad" + + runner.save_agent_thought( + "id", + None, + bad_obj, + None, + bad_obj, + bad_obj, + None, + [], + None, + ) + + assert mock_db_session.commit.called + + def test_messages_ids_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + runner.save_agent_thought("id", None, None, None, None, None, None, None, None) + assert mock_db_session.commit.called + + def test_success_dict_serialization(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought( + "id", + None, + {"a": 1}, + None, + {"b": 2}, + None, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + + +# ========================================================== +# organize_agent_user_prompt +# ========================================================== + + +class TestOrganizeUserPrompt: + def test_no_files(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_with_files_no_config(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_image_detail_low_fallback(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[]) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + +# ========================================================== +# organize_agent_history +# ========================================================== + + +class TestOrganizeHistory: + def test_empty(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_answer_only(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None) + 
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert any(isinstance(x, module.AssistantPromptMessage) for x in result) + + def test_skip_current_message(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None) + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input="invalid", + observation="invalid", + thought="thinking", + ) + msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_empty_tool_name_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=";", thought="thinking") + msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_valid_json_tool_flow(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=json.dumps({"tool1": {"x": 1}}), + observation=json.dumps({"tool1": "obs"}), + thought="thinking", + ) + + msg = mocker.MagicMock( + id="m100", + agent_thoughts=[thought], + answer=None, + app_model_config=None, + ) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + +# ========================================================== +# _convert_tool_to_prompt_message_tool (new coverage) +# ========================================================== + + +class TestConvertToolToPromptMessageTool: + def test_basic_conversion(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + runtime_param = mocker.MagicMock() + runtime_param.form = module.ToolParameter.ToolParameterForm.LLM + runtime_param.name = "param1" + runtime_param.llm_description = "desc" + runtime_param.required = True + runtime_param.input_schema = None + runtime_param.options = None + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + runtime_param.type = mock_type + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [runtime_param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + assert entity 
== tool_entity + + def test_full_conversion_multiple_params(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + # LLM param with input_schema override + param1 = mocker.MagicMock() + param1.form = module.ToolParameter.ToolParameterForm.LLM + param1.name = "p1" + param1.llm_description = "desc" + param1.required = True + param1.input_schema = {"type": "integer"} + param1.options = None + param1.type = mocker.MagicMock() + + # SYSTEM_FILES param should be skipped + param2 = mocker.MagicMock() + param2.form = module.ToolParameter.ToolParameterForm.LLM + param2.name = "file_param" + param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param1, param2] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + + assert entity == tool_entity + + +# ========================================================== +# _init_prompt_tools additional branches +# ========================================================== + + +class TestInitPromptToolsExtended: + def test_agent_tool_branch(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="agent_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity")) + + tools, prompts = runner._init_prompt_tools() + assert "agent_tool" in tools + + def test_exception_in_conversion(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="bad_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception) + + tools, prompts = runner._init_prompt_tools() + assert tools == {} + + +# ========================================================== +# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS) +# ========================================================== + + +class TestAdditionalCoverage: + def test_update_prompt_with_input_schema(self, runner, mocker): + tool = mocker.MagicMock() + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "p1" + param.required = False + param.llm_description = "desc" + param.options = None + param.input_schema = {"type": "number"} + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + param.type = mock_type + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"]["p1"]["type"] == "number" + + def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {"tool1": {"en_US": "existing"}} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert labels["tool1"]["en_US"] == "existing" + + def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + 
agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None) + assert agent.tool_meta_str == "meta_string" + + def test_convert_dataset_retriever_tool(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + assert prompt is not None + + def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"]) + mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock()) + + mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw)) + mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw)) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result is not None + + def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=None, thought="thinking") + msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1;tool2", + tool_input=json.dumps({"tool1": {}, "tool2": {}}), + observation=json.dumps({"tool1": "o1", "tool2": "o2"}), + thought="thinking", + ) + msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + # ================= Additional Surgical Coverage ================= + + def test_convert_tool_select_enum_branch(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = True + param.llm_description = "desc" + param.input_schema = None + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + 
tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + assert prompt_tool is not None + + +class TestConvertDatasetRetrieverTool: + def test_required_param_added(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + + assert prompt is not None + + +class TestBaseAgentRunnerInit: + def test_init_sets_stream_tool_call_and_files(self, mocker): + session = mocker.MagicMock() + session.query.return_value.where.return_value.count.return_value = 2 + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) + mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"]) + + llm = mocker.MagicMock() + llm.get_model_schema.return_value = mocker.MagicMock( + features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION] + ) + model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c") + + app_config = mocker.MagicMock() + app_config.app_id = "app1" + app_config.agent = None + app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"}) + app_config.additional_features = mocker.MagicMock(show_retrieve_source=True) + + app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"]) + message = mocker.MagicMock(id="msg1", conversation_id="conv1") + + runner = BaseAgentRunner( + tenant_id="tenant", + application_generate_entity=app_generate, + conversation=mocker.MagicMock(), + app_config=app_config, + model_config=mocker.MagicMock(), + config=mocker.MagicMock(), + queue_manager=mocker.MagicMock(), + message=message, + user_id="user", + model_instance=model_instance, + ) + + assert runner.stream_tool_call is True + assert runner.files == ["file1"] + assert runner.dataset_tools == ["ds_tool"] + assert runner.agent_thought_count == 2 + + +class TestBaseAgentRunnerCoverage: + def test_convert_tool_skips_non_llm_param(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = "NOT_LLM" + param.type = mocker.MagicMock() + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + + assert prompt_tool.parameters["properties"] == {} + + def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker): + dataset_tool = mocker.MagicMock() + dataset_tool.entity.identity.name = "ds" + runner.dataset_tools = [dataset_tool] + + mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock()) + 
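+        # Dataset retriever tools should be registered under their identity
+        # name ("ds") in the returned mapping, alongside agent-configured tools.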
+ tools, prompt_tools = runner._init_prompt_tools() + + assert tools["ds"] == dataset_tool + assert len(prompt_tools) == 1 + + def test_update_prompt_message_tool_select_enum(self, runner, mocker): + tool = mocker.MagicMock() + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = False + param.llm_description = "desc" + param.input_schema = None + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"] + + def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + tool_input = {"a": 1} + observation = {"b": 2} + tool_meta = {"c": 3} + + real_dumps = json.dumps + + def dumps_side_effect(value, *args, **kwargs): + if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False: + raise TypeError("fail") + return real_dumps(value, *args, **kwargs) + + mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect) + + runner.save_agent_thought( + "id", + "tool1", + tool_input, + None, + observation, + tool_meta, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + assert isinstance(agent.tool_meta_str, str) + + def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;;" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + labels = json.loads(agent.tool_labels_str) + assert "" not in labels + + def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + + system_message = module.SystemPromptMessage(content="sys") + + result = runner.organize_agent_history([system_message]) + + assert system_message in result + + def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=None, + observation=None, + thought="thinking", + ) + msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + mocker.patch.object( + runner, + "organize_agent_user_prompt", + return_value=module.UserPromptMessage(content="user"), + ) + + result = runner.organize_agent_history([]) + + assert any(isinstance(item, module.ToolPromptMessage) for item in result) diff 
--git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py new file mode 100644 index 0000000000..f6d1edbaf0 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -0,0 +1,551 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentScratchpadUnit +from core.agent.errors import AgentMaxIterationError +from dify_graph.model_runtime.entities.llm_entities import LLMUsage + + +class DummyRunner(CotAgentRunner): + """Concrete implementation for testing abstract methods.""" + + def __init__(self, **kwargs): + # Completely bypass BaseAgentRunner __init__ to avoid DB/session usage + for k, v in kwargs.items(): + setattr(self, k, v) + # Minimal required defaults + self.history_prompt_messages = [] + self.memory = None + + def _organize_prompt_messages(self): + return [] + + +@pytest.fixture +def runner(mocker): + # Prevent BaseAgentRunner __init__ from hitting database + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history", + return_value=[], + ) + # Prepare required constructor dependencies for BaseAgentRunner + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock() + application_generate_entity.model_conf.stop = [] + application_generate_entity.model_conf.provider = "openai" + application_generate_entity.model_conf.parameters = {} + application_generate_entity.trace_manager = None + application_generate_entity.invoke_from = "test" + + app_config = MagicMock() + app_config.agent = MagicMock() + app_config.agent.max_iteration = 1 + app_config.prompt_template.simple_prompt_template = "Hello {{name}}" + + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + model_instance.invoke_llm.return_value = [] + + model_config = MagicMock() + model_config.model = "test-model" + + queue_manager = MagicMock() + message = MagicMock() + + runner = DummyRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=MagicMock(), + app_config=app_config, + model_config=model_config, + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Patch internal methods to isolate behavior + runner._repack_app_generate_entity = MagicMock() + runner._init_prompt_tools = MagicMock(return_value=({}, [])) + runner.recalc_llm_max_tokens = MagicMock() + runner.create_agent_thought = MagicMock(return_value="thought-id") + runner.save_agent_thought = MagicMock() + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + runner.memory = None + runner.history_prompt_messages = [] + + return runner + + +class TestFillInputs: + @pytest.mark.parametrize( + ("instruction", "inputs", "expected"), + [ + ("Hello {{name}}", {"name": "John"}, "Hello John"), + ("No placeholders", {"name": "John"}, "No placeholders"), + ("{{a}}{{b}}", {"a": 1, "b": 2}, "12"), + ("{{x}}", {"x": None}, "None"), + ("", {"x": "y"}, ""), + ], + ) + def test_fill_in_inputs(self, runner, instruction, inputs, expected): + result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs) + assert result == expected + + +class TestConvertDictToAction: + def test_convert_valid_dict(self, runner): + action_dict = {"action": "test", "action_input": {"a": 1}} + action = 
runner._convert_dict_to_action(action_dict) + assert action.action_name == "test" + assert action.action_input == {"a": 1} + + def test_convert_missing_keys(self, runner): + with pytest.raises(KeyError): + runner._convert_dict_to_action({"invalid": 1}) + + +class TestFormatAssistantMessage: + def test_format_assistant_message_multiple_scratchpads(self, runner): + sp1 = AgentScratchpadUnit( + agent_response="resp1", + thought="thought1", + action_str="action1", + action=AgentScratchpadUnit.Action(action_name="tool", action_input={}), + observation="obs1", + ) + sp2 = AgentScratchpadUnit( + agent_response="final", + thought="", + action_str="", + action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"), + observation=None, + ) + result = runner._format_assistant_message([sp1, sp2]) + assert "Final Answer:" in result + + def test_format_with_final(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="Done", + thought="", + action_str="", + action=None, + observation=None, + ) + # Simulate final state via action name + scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done") + result = runner._format_assistant_message([scratchpad]) + assert "Final Answer" in result + + def test_format_with_action_and_observation(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="resp", + thought="thinking", + action_str="action", + action=None, + observation="obs", + ) + # Non-final state: provide a non-final action + scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + result = runner._format_assistant_message([scratchpad]) + assert "Thought:" in result + assert "Action:" in result + assert "Observation:" in result + + +class TestHandleInvokeAction: + def test_handle_invoke_action_tool_not_present(self, runner): + action = AgentScratchpadUnit.Action(action_name="missing", action_input={}) + response, meta = runner._handle_invoke_action(action, {}, []) + assert "there is not a tool named" in response + + def test_tool_with_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1})) + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("result", [], MagicMock(to_dict=lambda: {})), + ) + + response, meta = runner._handle_invoke_action(action, tool_instances, []) + assert response == "result" + + +class TestOrganizeHistoricPromptMessages: + def test_empty_history(self, runner, mocker): + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt", + return_value=[], + ) + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRun: + def test_run_handles_empty_parser_output(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert isinstance(results, list) + + def test_run_with_action_and_tool_invocation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + 
return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_respects_max_iteration_boundary(self, runner, mocker): + runner.app_config.agent.max_iteration = 1 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_basic_flow(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {"name": "John"})) + assert results + + def test_run_max_iteration_error(self, runner, mocker): + runner.app_config.agent.max_iteration = 0 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {})) + + def test_run_increase_usage_aggregation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + runner.app_config.agent.max_iteration = 2 + + usage_1 = LLMUsage.empty_usage() + usage_1.prompt_tokens = 1 + usage_1.completion_tokens = 1 + usage_1.total_tokens = 2 + usage_1.prompt_price = 1 + usage_1.completion_price = 1 + usage_1.total_price = 2 + + usage_2 = LLMUsage.empty_usage() + usage_2.prompt_tokens = 1 + usage_2.completion_tokens = 1 + usage_2.total_tokens = 2 + usage_2.prompt_price = 1 + usage_2.completion_price = 1 + usage_2.total_price = 2 + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + handle_output = mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[ + [action], + [], + ], + ) + + def _handle_side_effect(chunks, usage_dict): + call_index = handle_output.call_count + usage_dict["usage"] = usage_1 if call_index == 1 else usage_2 + return [action] if call_index == 1 else [] + + handle_output.side_effect = _handle_side_effect + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + results = list(runner.run(message, "query", {})) + final_usage = results[-1].delta.usage + assert final_usage is not None + assert final_usage.prompt_tokens == 2 + assert final_usage.completion_tokens == 2 + assert final_usage.total_tokens == 4 + assert final_usage.prompt_price == 2 + assert final_usage.completion_price == 2 + assert final_usage.total_price == 4 + + def test_run_when_no_action_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + 
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "" + + def test_run_usage_missing_key_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + + list(runner.run(message, "query", {})) + + def test_run_prompt_tool_update_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + # First iteration → action + # Second iteration → no action (empty list) + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[[action], []], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.app_config.agent.max_iteration = 5 + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + + list(runner.run(message, "query", {})) + + runner.update_prompt_message_tool.assert_called_once() + + def test_historic_with_assistant_and_tool_calls(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + + assistant = AssistantPromptMessage(content="thinking") + assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] + + tool_msg = ToolPromptMessage(content="obs", tool_call_id="1") + + runner.history_prompt_messages = [assistant, tool_msg] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + def test_historic_final_flush_branch(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + assistant = AssistantPromptMessage(content="final") + runner.history_prompt_messages = [assistant] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + +class TestInitReactState: + def test_init_react_state_resets_state(self, runner, mocker): + mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"]) + runner._agent_scratchpad = ["old"] + runner._query = "old" + + runner._init_react_state("new-query") + + assert runner._query == "new-query" + assert runner._agent_scratchpad == [] + assert runner._historic_prompt_messages == ["historic"] + + +class TestHandleInvokeActionExtended: + def test_tool_with_invalid_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json") + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})), + ) + + message_file_ids = [] + response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids) + + assert response == "ok" + assert message_file_ids == ["file1"] + runner.queue_manager.publish.assert_called() + + +class TestFillInputsEdgeCases: + def test_fill_inputs_with_empty_inputs(self, runner): + result = 
runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {}) + assert result == "Hello {{x}}" + + def test_fill_inputs_with_exception_in_replace(self, runner): + class BadValue: + def __str__(self): + raise Exception("fail") + + # Should silently continue on exception + result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()}) + assert result == "Hello {{x}}" + + +class TestOrganizeHistoricPromptMessagesExtended: + def test_user_message_flushes_scratchpad(self, runner, mocker): + from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + user_message = UserPromptMessage(content="Hi") + + runner.history_prompt_messages = [user_message] + + mock_transform = mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + ) + mock_transform.return_value.get_prompt.return_value = ["final"] + + result = runner._organize_historic_prompt_messages([]) + assert result == ["final"] + + def test_tool_message_without_scratchpad_raises(self, runner): + from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + + runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] + + with pytest.raises(NotImplementedError): + runner._organize_historic_prompt_messages([]) + + def test_agent_history_transform_invocation(self, runner, mocker): + mock_transform = MagicMock() + mock_transform.get_prompt.return_value = [] + + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + return_value=mock_transform, + ) + + runner.history_prompt_messages = [] + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRunAdditionalBranches: + def test_run_with_no_action_final_answer_empty(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=["thinking"], + ) + + results = list(runner.run(message, "query", {})) + assert any(hasattr(r, "delta") for r in results) + + def test_run_with_final_answer_action_string(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "done" + + def test_run_with_final_answer_action_dict(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert json.loads(results[-1].delta.message.content) == {"a": 1} + + def test_run_with_string_final_answer(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + # Remove invalid branch: Pydantic enforces str|dict for action_input + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "12345" diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py 
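The `TestRunAdditionalBranches` cases above pin down how a `Final Answer` action's `action_input` becomes the streamed message content: strings pass through verbatim, dicts are JSON-serialized. A minimal restatement of that contract, for reference only — the production logic lives in `cot_agent_runner`, which is not part of this diff:

```python
import json

# Illustrative restatement of the behavior the tests above assert;
# not the production implementation from core.agent.cot_agent_runner.
def render_final_answer(action_input: str | dict) -> str:
    if isinstance(action_input, dict):
        return json.dumps(action_input)
    return action_input

assert render_final_answer("12345") == "12345"
assert json.loads(render_final_answer({"a": 1})) == {"a": 1}
```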
b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py new file mode 100644 index 0000000000..f9d69d1196 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -0,0 +1,215 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from tests.unit_tests.core.agent.conftest import ( + DummyAgentConfig, + DummyAppConfig, + DummyTool, +) +from tests.unit_tests.core.agent.conftest import ( + DummyPromptEntity as DummyPrompt, +) + + +class DummyFileUploadConfig: + def __init__(self, image_config=None): + self.image_config = image_config + + +class DummyImageConfig: + def __init__(self, detail=None): + self.detail = detail + + +class DummyGenerateEntity: + def __init__(self, file_upload_config=None): + self.file_upload_config = file_upload_config + + +class DummyUnit: + def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def runner(): + runner = CotChatAgentRunner.__new__(CotChatAgentRunner) + runner._instruction = "test_instruction" + runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")] + runner._query = "user query" + runner._agent_scratchpad = [] + runner.files = [] + runner.application_generate_entity = DummyGenerateEntity() + runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"]) + return runner + + +class TestOrganizeSystemPrompt: + def test_organize_system_prompt_success(self, runner, mocker): + first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}" + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt))) + + mocker.patch( + "core.agent.cot_chat_agent_runner.jsonable_encoder", + return_value=[{"name": "tool1"}, {"name": "tool2"}], + ) + + result = runner._organize_system_prompt() + + assert "test_instruction" in result.content + assert "tool1" in result.content + assert "tool2" in result.content + assert "tool1, tool2" in result.content + + def test_organize_system_prompt_missing_agent(self, runner): + runner.app_config = DummyAppConfig(agent=None) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + def test_organize_system_prompt_missing_prompt(self, runner): + runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None)) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + +class TestOrganizeUserQuery: + @pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")]) + def test_organize_user_query_no_files(self, runner, files): + runner.files = files + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert result[0].content == "query" + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = 
mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.LOW, + ) + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + + image_config = DummyImageConfig(detail="high") + runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config)) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner): + mock_to_prompt.return_value = TextPromptMessageContent(data="file_content") + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + + +class TestOrganizePromptMessages: + def test_no_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + result = runner._organize_prompt_messages() + assert "system" in result + assert "query" in result + runner._organize_historic_prompt_messages.assert_called_once() + + def test_with_final_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit(final=True, agent_response="done") + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Final Answer: done" in combined + + def test_with_thought_action_observation(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit( + final=False, + thought="thinking", + action_str="action", + observation="observe", + ) + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + 
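The assertions that follow (and the completion-runner tests later in this diff) fix the text each scratchpad unit contributes to the prompt. A self-contained sketch of that serialization, using a stand-in dataclass rather than the real `AgentScratchpadUnit`; the exact joining and whitespace in the production prompt may differ:

```python
from dataclasses import dataclass


@dataclass
class Unit:  # stand-in for the real scratchpad unit, illustration only
    final: bool = False
    thought: str | None = None
    action_str: str | None = None
    observation: str | None = None
    agent_response: str | None = None


def render(unit: Unit) -> str:
    # A final unit contributes a closing answer; a non-final unit contributes
    # the ReAct triple, skipping whichever parts are absent.
    if unit.final:
        return f"Final Answer: {unit.agent_response}"
    parts = [f"Thought: {unit.thought}"]
    if unit.action_str:
        parts.append(f"Action: {unit.action_str}")
    if unit.observation:
        parts.append(f"Observation: {unit.observation}")
    return "\n".join(parts)


assert render(Unit(thought="thinking", action_str="action", observation="observe")) == (
    "Thought: thinking\nAction: action\nObservation: observe"
)
assert render(Unit(final=True, agent_response="done")) == "Final Answer: done"
```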
assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: thinking" in combined + assert "Action: action" in combined + assert "Observation: observe" in combined + + def test_multiple_units_mixed(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + units = [ + DummyUnit(final=False, thought="t1"), + DummyUnit(final=True, agent_response="done"), + ] + runner._agent_scratchpad = units + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: t1" in combined + assert "Final Answer: done" in combined diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py new file mode 100644 index 0000000000..ab822bb57d --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -0,0 +1,234 @@ +import json + +import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def runner(mocker, dummy_tool_factory): + runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner) + + runner._instruction = "Test instruction" + runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")] + runner._query = "What is Python?" 
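The `TestOrganizeHistoricPrompt` cases below assert that completion-style runners flatten chat history to plain text: user turns become `Question: ...` lines, assistant text is appended as-is, and non-text content parts are dropped. A simplified sketch of that flattening, with role/content pairs standing in for the real prompt-message classes (the production `_organize_historic_prompt` is not shown in this diff):

```python
def flatten_history(turns: list[tuple[str, object]]) -> str:
    # turns: (role, content) pairs; content is a string or a list of parts,
    # of which only plain-text parts survive.
    lines: list[str] = []
    for role, content in turns:
        if isinstance(content, list):
            content = "".join(part for part in content if isinstance(part, str))
        if not content:
            continue
        lines.append(f"Question: {content}" if role == "user" else str(content))
    return "\n".join(lines)


assert flatten_history([("user", "Hello"), ("assistant", "Hi there")]) == "Question: Hello\nHi there"
assert flatten_history([("assistant", [object()])]) == ""  # non-text parts ignored
```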
+ runner._agent_scratchpad = [] + + mocker.patch( + "core.agent.cot_completion_agent_runner.jsonable_encoder", + side_effect=lambda tools: [{"name": t.name} for t in tools], + ) + + return runner + + +# ====================================================== +# _organize_instruction_prompt Tests +# ====================================================== + + +class TestOrganizeInstructionPrompt: + def test_success_all_placeholders( + self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = ( + "{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}" + ) + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + result = runner._organize_instruction_prompt() + + assert "Test instruction" in result + assert "toolA" in result + assert "toolB" in result + tools_payload = json.loads(result.split(" | ")[1]) + assert {item["name"] for item in tools_payload} == {"toolA", "toolB"} + + def test_agent_none_raises(self, runner, dummy_app_config_factory): + runner.app_config = dummy_app_config_factory(agent=None) + with pytest.raises(ValueError, match="Agent configuration is not set"): + runner._organize_instruction_prompt() + + def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory): + runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None)) + with pytest.raises(ValueError, match="prompt entity is not set"): + runner._organize_instruction_prompt() + + +# ====================================================== +# _organize_historic_prompt Tests +# ====================================================== + + +class TestOrganizeHistoricPrompt: + def test_with_user_and_assistant_string(self, runner, mocker): + user_msg = UserPromptMessage(content="Hello") + assistant_msg = AssistantPromptMessage(content="Hi there") + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[user_msg, assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Question: Hello" in result + assert "Hi there" in result + + def test_assistant_list_with_text_content(self, runner, mocker): + text_content = TextPromptMessageContent(data="Partial answer") + assistant_msg = AssistantPromptMessage(content=[text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Partial answer" in result + + def test_assistant_list_with_non_text_content_ignored(self, runner, mocker): + non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png") + assistant_msg = AssistantPromptMessage(content=[non_text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + def test_empty_history(self, runner, mocker): + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + +# ====================================================== +# _organize_prompt_messages Tests +# ====================================================== + + +class TestOrganizePromptMessages: + def test_full_flow_with_scratchpad( + self, + runner, + mocker, + 
dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"), + dummy_scratchpad_unit_factory(final=True, agent_response="Done"), + ] + + result = runner._organize_prompt_messages() + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + + content = result[0].content + + assert "History" in content + assert "Thought: Thinking" in content + assert "Action: Act" in content + assert "Observation: Obs" in content + assert "Final Answer: Done" in content + assert "Question: What is Python?" in content + + def test_no_scratchpad( + self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = None + + result = runner._organize_prompt_messages() + + assert "Question: What is Python?" in result[0].content + + @pytest.mark.parametrize( + ("thought", "action", "observation"), + [ + ("T", None, None), + ("T", "A", None), + ("T", None, "O"), + ], + ) + def test_partial_scratchpad_units( + self, + runner, + mocker, + thought, + action, + observation, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory( + final=False, + thought=thought, + action_str=action, + observation=observation, + ) + ] + + result = runner._organize_prompt_messages() + content = result[0].content + + assert "Thought:" in content + if action: + assert "Action:" in content + if observation: + assert "Observation:" in content diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py new file mode 100644 index 0000000000..299c9b31d2 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -0,0 +1,452 @@ +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + DocumentPromptMessageContent, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) + +# 
============================== +# Dummy Helper Classes +# ============================== + + +def build_usage(pt=1, ct=1, tt=2) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.prompt_tokens = pt + usage.completion_tokens = ct + usage.total_tokens = tt + usage.prompt_price = 0 + usage.completion_price = 0 + usage.total_price = 0 + return usage + + +class DummyMessage: + def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None): + self.content: str | None = content + self.tool_calls: list[Any] = tool_calls or [] + + +class DummyDelta: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + + +class DummyChunk: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.delta: DummyDelta = DummyDelta(message=message, usage=usage) + + +class DummyResult: + def __init__( + self, + message: DummyMessage | None = None, + usage: LLMUsage | None = None, + prompt_messages: list[DummyMessage] | None = None, + ): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + self.prompt_messages: list[DummyMessage] = prompt_messages or [] + self.system_fingerprint: str = "" + + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def runner(mocker): + # Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.__init__", + return_value=None, + ) + + # Patch streaming chunk models to avoid validation on dummy message objects + mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock) + mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock) + + app_config = MagicMock() + app_config.agent = MagicMock(max_iteration=2) + app_config.prompt_template = MagicMock(simple_prompt_template="system") + + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock(parameters={}, stop=None) + application_generate_entity.trace_manager = MagicMock() + application_generate_entity.invoke_from = "test" + application_generate_entity.app_config = MagicMock(app_id="app") + application_generate_entity.file_upload_config = None + + queue_manager = MagicMock() + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + + message = MagicMock(id="msg1") + conversation = MagicMock(id="conv1") + + runner = FunctionCallAgentRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=conversation, + app_config=app_config, + model_config=MagicMock(), + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Manually inject required attributes normally set by BaseAgentRunner + runner.tenant_id = "tenant" + runner.application_generate_entity = application_generate_entity + runner.conversation = conversation + runner.app_config = app_config + runner.model_config = MagicMock() + runner.config = MagicMock() + runner.queue_manager = queue_manager + runner.message = message + runner.user_id = "user" + runner.model_instance = model_instance + + runner.stream_tool_call = False + runner.memory = None + runner.history_prompt_messages = [] + runner._current_thoughts = [] + runner.files = [] + runner.agent_callback = MagicMock() + + runner._init_prompt_tools = 
MagicMock(return_value=({}, [])) + runner.create_agent_thought = MagicMock(return_value="thought1") + runner.save_agent_thought = MagicMock() + runner.recalc_llm_max_tokens = MagicMock() + runner.update_prompt_message_tool = MagicMock() + + return runner + + +# ============================== +# Tool Call Checks +# ============================== + + +class TestToolCallChecks: + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_tool_calls(self, runner, tool_calls, expected): + chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_tool_calls(chunk) is expected + + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_blocking_tool_calls(self, runner, tool_calls, expected): + result = DummyResult(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_blocking_tool_calls(result) is expected + + +# ============================== +# Extract Tool Calls +# ============================== + + +class TestExtractToolCalls: + def test_extract_tool_calls_with_valid_json(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {"a": 1})] + + def test_extract_tool_calls_empty_arguments(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "" + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {})] + + def test_extract_blocking_tool_calls(self, runner): + tool_call = MagicMock() + tool_call.id = "2" + tool_call.function.name = "block" + tool_call.function.arguments = json.dumps({"x": 2}) + + result = DummyResult(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_blocking_tool_calls(result) + + assert calls == [("2", "block", {"x": 2})] + + +# ============================== +# System Message Initialization +# ============================== + + +class TestInitSystemMessage: + def test_init_system_message_empty_prompt_messages(self, runner): + result = runner._init_system_message("system", []) + assert len(result) == 1 + + def test_init_system_message_insert_at_start(self, runner): + msgs = [MagicMock()] + result = runner._init_system_message("system", msgs) + assert result[0].content == "system" + + def test_init_system_message_no_template(self, runner): + result = runner._init_system_message("", []) + assert result == [] + + +# ============================== +# Organize User Query +# ============================== + + +class TestOrganizeUserQuery: + def test_without_files(self, runner): + result = runner._organize_user_query("query", []) + assert len(result) == 1 + + def test_with_none_query(self, runner): + result = runner._organize_user_query(None, []) + assert len(result) == 1 + + def test_with_files_uses_image_detail_config(self, runner, mocker): + file_content = TextPromptMessageContent(data="file-content") + mock_to_prompt = mocker.patch( + "core.agent.fc_agent_runner.file_manager.to_prompt_message_content", + return_value=file_content, + ) + + image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH) + runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config) + runner.files = 
["file1"] + + result = runner._organize_user_query("query", []) + + assert len(result) == 1 + assert isinstance(result[0].content, list) + mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) + + +# ============================== +# Clear User Prompt Images +# ============================== + + +class TestClearUserPromptImageMessages: + def test_clear_text_and_image_content(self, runner): + text = MagicMock() + text.type = "text" + text.data = "hello" + + image = MagicMock() + image.type = "image" + image.data = "img" + + user_msg = MagicMock() + user_msg.__class__.__name__ = "UserPromptMessage" + user_msg.content = [text, image] + + result = runner._clear_user_prompt_image_messages([user_msg]) + assert isinstance(result, list) + + def test_clear_includes_file_placeholder(self, runner): + text = TextPromptMessageContent(data="hello") + image = ImagePromptMessageContent(format="url", mime_type="image/png") + document = DocumentPromptMessageContent(format="url", mime_type="application/pdf") + + user_msg = UserPromptMessage(content=[text, image, document]) + + result = runner._clear_user_prompt_image_messages([user_msg]) + + assert result[0].content == "hello\n[image]\n[file]" + + +# ============================== +# Run Method Tests +# ============================== + + +class TestRunMethod: + def test_run_non_streaming_no_tool_calls(self, runner): + message = MagicMock(id="m1") + dummy_message = DummyMessage(content="hello") + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + runner.queue_manager.publish.assert_called() + + queue_calls = runner.queue_manager.publish.call_args_list + assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls) + + def test_run_streaming_branch(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_streaming_tool_calls_list_content(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [generator(), final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_non_streaming_list_content(self, runner): + message = MagicMock(id="m1") + content = [TextPromptMessageContent(data="hi")] + dummy_message = DummyMessage(content=content) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + assert runner.save_agent_thought.call_args.kwargs["thought"] == 
"hi" + + def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + real_dumps = json.dumps + + def flaky_dumps(obj, *args, **kwargs): + if kwargs.get("ensure_ascii") is False: + return real_dumps(obj, *args, **kwargs) + raise TypeError("boom") + + mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps) + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_with_missing_tool_instance(self, runner): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "missing" + tool_call.function.arguments = json.dumps({}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_with_tool_instance_and_files(self, runner, mocker): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + tool_instance = MagicMock() + prompt_tool = MagicMock() + prompt_tool.name = "tool" + runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool]) + + tool_invoke_meta = MagicMock() + tool_invoke_meta.to_dict.return_value = {"ok": True} + mocker.patch( + "core.agent.fc_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], tool_invoke_meta), + ) + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + assert any( + isinstance(call.args[0], QueueMessageFileEvent) + and call.args[0].message_file_id == "file1" + and call.args[1] == PublishFrom.APPLICATION_MANAGER + for call in runner.queue_manager.publish.call_args_list + ) + + def test_run_max_iteration_error(self, runner): + runner.app_config.agent.max_iteration = 0 + + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "{}" + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query")) diff --git a/api/tests/unit_tests/core/agent/test_plugin_entities.py b/api/tests/unit_tests/core/agent/test_plugin_entities.py new file mode 100644 index 0000000000..9955190aca --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_plugin_entities.py @@ -0,0 +1,324 @@ +"""Unit tests for 
core.agent.plugin_entities. + +Covers entities such as AgentFeature, AgentProviderEntityWithPlugin, +AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter, +AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on +Pydantic ValidationError behavior and pytest fixtures for validation and +mocking; ensure entity invariants and validation rules remain stable. +""" + +import pytest +from pydantic import ValidationError + +from core.agent.plugin_entities import ( + AgentFeature, + AgentProviderEntityWithPlugin, + AgentStrategyEntity, + AgentStrategyIdentity, + AgentStrategyParameter, + AgentStrategyProviderEntity, + AgentStrategyProviderIdentity, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity + +# ========================================================= +# Fixtures +# ========================================================= + + +@pytest.fixture +def mock_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyIdentity) + + +@pytest.fixture +def mock_provider_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyProviderIdentity) + + +# ========================================================= +# AgentStrategyParameterType Tests +# ========================================================= + + +class TestAgentStrategyParameterType: + @pytest.mark.parametrize( + "enum_member", + list(AgentStrategyParameter.AgentStrategyParameterType), + ) + def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.as_normal_type", + return_value="normalized", + ) + + result = enum_member.as_normal_type() + + mock_func.assert_called_once_with(enum_member) + assert result == "normalized" + + def test_as_normal_type_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.as_normal_type", + side_effect=RuntimeError("boom"), + ) + + with pytest.raises(RuntimeError): + enum_member.as_normal_type() + + @pytest.mark.parametrize( + ("enum_member", "value"), + [ + (AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"), + (AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10), + (AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True), + (AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}), + (AgentStrategyParameter.AgentStrategyParameterType.STRING, None), + (AgentStrategyParameter.AgentStrategyParameterType.FILES, []), + ], + ) + def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + return_value="casted", + ) + + result = enum_member.cast_value(value) + + mock_func.assert_called_once_with(enum_member, value) + assert result == "casted" + + def test_cast_value_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + side_effect=ValueError("invalid"), + ) + + with pytest.raises(ValueError): + enum_member.cast_value("bad") + + +# ========================================================= +# AgentStrategyParameter Tests +# ========================================================= + + +class TestAgentStrategyParameter: + def test_valid_creation_minimal(self) -> None: + # bypass base PluginParameter required fields 
using model_construct + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=None, + ) + assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING + assert param.help is None + + def test_valid_creation_with_help(self) -> None: + help_obj = I18nObject(en_US="test") + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=help_obj, + ) + assert param.help == help_obj + + @pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}]) + def test_invalid_type_raises_validation_error(self, invalid_type) -> None: + with pytest.raises(ValidationError) as exc_info: + AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y")) + + assert any(error["loc"] == ("type",) for error in exc_info.value.errors()) + + def test_init_frontend_parameter_calls_external(self, mocker) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + return_value="frontend", + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + result = param.init_frontend_parameter("value") + + mock_func.assert_called_once_with(param, param.type, "value") + assert result == "frontend" + + def test_init_frontend_parameter_propagates_exception(self, mocker) -> None: + mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + side_effect=RuntimeError("error"), + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + with pytest.raises(RuntimeError): + param.init_frontend_parameter("value") + + +# ========================================================= +# AgentStrategyProviderEntity Tests +# ========================================================= + + +class TestAgentStrategyProviderEntity: + def test_creation_with_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="plugin-123", + ) + assert entity.plugin_id == "plugin-123" + + def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="", + ) + assert entity.plugin_id == "" + + def test_creation_without_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity(identity=mock_provider_identity) + assert entity.plugin_id is None + + def test_invalid_identity_raises(self) -> None: + with pytest.raises(ValidationError): + AgentStrategyProviderEntity(identity="invalid") + + +# ========================================================= +# AgentStrategyEntity Tests +# ========================================================= + + +class TestAgentStrategyEntity: + def test_parameters_default_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + ) + assert entity.parameters == [] + + def test_parameters_none_converted_to_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=None, + ) + assert entity.parameters == [] + + def test_parameters_preserved(self, mock_identity) -> 
None: + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[param], + ) + assert entity.parameters == [param] + + def test_invalid_parameters_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters="invalid", + ) + + @pytest.mark.parametrize( + "features", + [ + None, + [], + [AgentFeature.HISTORY_MESSAGES], + ], + ) + def test_features_valid(self, mock_identity, features) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features=features, + ) + assert entity.features == features + + def test_invalid_features_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features="invalid", + ) + + def test_output_schema_and_meta_version(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + output_schema={"type": "object"}, + meta_version="v1", + ) + assert entity.output_schema == {"type": "object"} + assert entity.meta_version == "v1" + + def test_missing_required_fields_raise(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity(identity=mock_identity) + + +# ========================================================= +# AgentProviderEntityWithPlugin Tests +# ========================================================= + + +class TestAgentProviderEntityWithPlugin: + def test_default_strategies_empty(self, mock_provider_identity) -> None: + entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity) + assert entity.strategies == [] + + def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None: + strategy = AgentStrategyEntity.model_construct( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[], + ) + + entity = AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies=[strategy], + ) + assert entity.strategies == [strategy] + + def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None: + with pytest.raises(ValidationError): + AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies="invalid", + ) + + +# ========================================================= +# Inheritance Smoke Tests +# ========================================================= + + +class TestInheritanceBehavior: + def test_agent_strategy_identity_inherits(self) -> None: + assert issubclass(AgentStrategyIdentity, ToolIdentity) + + def test_agent_strategy_provider_identity_inherits(self) -> None: + assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity) diff --git a/api/tests/unit_tests/core/app/apps/__init__.py b/api/tests/unit_tests/core/app/apps/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py b/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py new file mode 
100644 index 0000000000..6ca4f60459 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from models.model import AppMode + + +class TestAdvancedChatAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.ADVANCED_CHAT + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + config = kwargs.get("config") if kwargs else args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=_add_key("opening_statement", 2), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 3), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 4), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 5), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 6), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 7), + ), + ): + filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["opening_statement"] == 2 + assert filtered["suggested_questions_after_answer"] == 3 + assert filtered["speech_to_text"] == 4 + assert filtered["text_to_speech"] == 5 + assert filtered["retriever_resource"] == 6 + assert filtered["sensitive_word_avoidance"] == 7 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py new file mode 100644 index 0000000000..e2618d960c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -0,0 +1,1258 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, ValidationError + +from constants import UUID_NIL +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from 
core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestAdvancedChatAppGeneratorValidation: + def test_generate_requires_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query is required"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_generate_requires_string_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query must be a string"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}, "query": 123}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_single_iteration_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestAdvancedChatAppGeneratorInternals: + @staticmethod + def _build_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + def test_generate_loads_conversation_and_files(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + + conversation = SimpleNamespace(id="conversation-id") + built_files: list[object] = [] + build_files_called = {"called": False} + captured: dict[str, object] = {} + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.ConversationService.get_conversation", + lambda **kwargs: conversation, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda *args, **kwargs: {"enabled": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.file_factory.build_from_mappings", + lambda **kwargs: build_files_called.update({"called": True}) or built_files, + ) + monkeypatch.setattr( + 
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: kwargs["user_inputs"]) + + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user-id" + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(features_dict={}), + user=user, + args={ + "query": "hello", + "inputs": {"k": "v"}, + "conversation_id": "conversation-id", + "files": [{"id": "f"}], + }, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert captured["conversation"] is conversation + assert captured["application_generate_entity"].files == built_files + assert build_files_called["called"] is True + + def test_resume_delegates_to_generate(self, monkeypatch): + generator = AdvancedChatAppGenerator() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=self._build_app_config(), + inputs={}, + query="hello", + files=[], + user_id="user", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + captured: dict[str, object] = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"resumed": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.resume( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conversation-id"), + message=SimpleNamespace(id="message-id"), + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=SimpleNamespace(), + pause_state_config=None, + ) + + assert result == {"resumed": True} + assert captured["graph_runtime_state"] is not None + + def test_single_iteration_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + + monkeypatch.setattr( + 
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow): + prefill_calls.append(workflow) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_iteration_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(id="workflow-id"), + node_id="node-1", + user=SimpleNamespace(id="user-id"), + args={"inputs": {"foo": "bar"}}, + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1" + + def test_single_loop_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow): + prefill_calls.append(workflow) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_loop_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + 
workflow=SimpleNamespace(id="workflow-id"), + node_id="node-2", + user=SimpleNamespace(id="user-id"), + args={"inputs": {"foo": "bar"}}, + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_loop_run.node_id == "node-2" + + def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-1") + db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) + captured: dict[str, object] = {} + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", lambda *args: (conversation, message)) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 2) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.PauseStatePersistenceLayer", + lambda **kwargs: "pause-layer", + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: captured.update(kwargs) or {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: {"response": response, "invoke_from": invoke_from}, + ) + + pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") + + response = generator._generate( + workflow=SimpleNamespace(features={"feature": True}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + 
workflow_node_execution_repository=SimpleNamespace(), + conversation=None, + message=None, + stream=False, + pause_state_config=pause_state_config, + ) + + assert response["response"] == {"raw": True} + assert thread_data["started"] is True + assert "pause-layer" in thread_data["kwargs"]["graph_engine_layers"] + assert generator._dialogue_count == 3 + db_session.commit.assert_called_once() + db_session.refresh.assert_called_once_with(conversation) + db_session.close.assert_called_once() + assert captured["draft_var_saver_factory"] == "draft-factory" + + def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-2") + db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + init_records = MagicMock() + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", init_records) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 0) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: response, + ) + + response = generator._generate( + workflow=SimpleNamespace(features={}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert response == {"raw": True} + 
init_records.assert_not_called() + assert thread_data["started"] is True + db_session.commit.assert_not_called() + db_session.refresh.assert_not_called() + db_session.close.assert_called_once() + + def test_generate_worker_raises_when_workflow_not_found(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock(return_value=None) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="Workflow not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_raises_when_app_not_found_for_internal_call(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + None, + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="App not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + 
conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_handles_stopped_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise GenerateTaskStoppedError() + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_not_called() + + def test_generate_worker_handles_validation_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _ValidationModel(BaseModel): + value: int + + try: + _ValidationModel(value="invalid") + except ValidationError as error: + validation_error = error + else: + raise AssertionError("validation error should be created") + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = 
MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise validation_error + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def test_generate_worker_handles_value_and_unknown_errors(self, monkeypatch): + app_config = self._build_app_config() + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + def _make_runner(error: Exception): + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise error + + return _Runner + + for raised_error in [ValueError("bad input"), RuntimeError("unexpected")]: + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", + _make_runner(raised_error), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.dify_config", SimpleNamespace(DEBUG=True)) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def 
test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + def test_handle_response_re_raises_value_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs): + _ = kwargs + + def process(self): + raise ValueError("other error") + + logger_exception = MagicMock() + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.logger.exception", logger_exception) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", _Pipeline) + + with pytest.raises(ValueError, match="other error"): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + logger_exception.assert_called_once() + + def test_refresh_model_returns_detached_model(self, monkeypatch): + source_model = SimpleNamespace(id="source-id") + detached_model = SimpleNamespace(id="source-id", detached=True) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, model_type, model_id): + _ = model_type + return detached_model if model_id == "source-id" else None + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) + + refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) + + assert refreshed is detached_model + + def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): + generator = 
AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.SERVICE_API, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT)) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Runner: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def run(self): + from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + raise InvokeAuthorizationError("bad key") + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="end-user-id", session_id="session-id"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + assert queue_manager.publish_error.called + + def test_generate_debugger_enables_retrieve_source(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + 
monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user" + + result = generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello\x00", "inputs": {}}, + invoke_from=InvokeFrom.DEBUGGER, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert app_config.additional_features.show_retrieve_source is True + assert captured["application_generate_entity"].query == "hello" + + def test_generate_service_api_sets_parent_message_id(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models.model import EndUser + + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + user.id = "end-user" + + generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello", "inputs": {}, "parent_message_id": "p1"}, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id="run-id", + streaming=False, + ) + + assert 
captured["application_generate_entity"].parent_message_id == UUID_NIL diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py new file mode 100644 index 0000000000..5b199e0c52 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -0,0 +1,96 @@ +from collections.abc import Generator + +from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, +) +from dify_graph.enums import WorkflowNodeExecutionStatus + + +class TestAdvancedChatGenerateResponseConverter: + def test_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + assert "usage" not in response["metadata"] + + def test_stream_simple_response_includes_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_start, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_finish, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py similarity index 100% rename from api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py rename to api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py diff --git 
a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..67f87710a1 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -0,0 +1,600 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueMessageReplaceEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import NodeType +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import MessageStatus +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + message = SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ) + conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=conversation, + message=message, + user=user, + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestAdvancedChatGenerateTaskPipeline: + def test_ensure_workflow_initialized_raises(self): + pipeline = _make_pipeline() + + with pytest.raises(ValueError, match="workflow run not initialized"): + pipeline._ensure_workflow_initialized() + + def 
test_to_blocking_response_returns_message_end(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "done" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"}) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "done" + assert response.data.metadata == {"k": "v"} + + def test_handle_text_chunk_event_updates_state(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager = SimpleNamespace( + message_to_stream_response=lambda **kwargs: MessageEndStreamResponse( + task_id="task", id="message-id", metadata={} + ) + ) + + event = SimpleNamespace(text="hi", from_variable_selector=None) + + responses = list(pipeline._handle_text_chunk_event(event)) + + assert pipeline._task_state.answer == "hi" + assert responses + + def test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + pipeline._database_session = _fake_session + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace()) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_run_id == "run-id" + assert responses == ["started"] + + def test_message_end_to_stream_response_strips_annotation_reply(self): + pipeline = _make_pipeline() + pipeline._task_state.metadata.annotation_reply = AnnotationReply( + id="ann", + account=AnnotationReplyAccount(id="acc", name="acc"), + ) + + response = pipeline._message_end_to_stream_response() + + assert "annotation_reply" not in response.metadata + + def test_handle_output_moderation_chunk_publishes_stop(self): + pipeline = _make_pipeline() + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + pipeline._base_task_pipeline.queue_manager = SimpleNamespace( + publish=lambda event, pub_from: events.append(event) + ) + + result = 
pipeline._handle_output_moderation_chunk("ignored") + + assert result is True + assert pipeline._task_state.answer == "final" + assert any(isinstance(event, QueueTextChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_node_succeeded_event_records_files(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [ + {"type": "file", "transfer_method": "local"} + ] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + event = SimpleNamespace( + node_type=NodeType.ANSWER, + outputs={"k": "v"}, + node_execution_id="exec", + node_id="node", + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + assert pipeline._recorded_files + + def test_iteration_and_loop_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: ( + "iter_start" + ) + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: ( + "iter_done" + ) + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + + def test_workflow_finish_handlers(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + 
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"] + pipeline._persist_human_input_extra_content = lambda **kwargs: None + pipeline._save_message = lambda **kwargs: None + pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id") + + @contextmanager + def _fake_session(): + yield SimpleNamespace(scalar=lambda *args, **kwargs: None) + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) + assert len(succeeded_responses) == 2 + assert isinstance(succeeded_responses[0], MessageEndStreamResponse) + assert succeeded_responses[1] == "finish" + + partial_success_responses = list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ) + ) + assert len(partial_success_responses) == 2 + assert isinstance(partial_success_responses[0], MessageEndStreamResponse) + assert partial_success_responses[1] == "finish" + assert ( + list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0] + == "finish" + ) + assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [ + "pause" + ] + + def test_node_failure_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + failed_event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + exc_event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"] + assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"] + + def test_handle_text_chunk_event_tracks_streaming_metrics(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk") + + event = SimpleNamespace(text="hi", from_variable_selector=["a"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert responses == ["chunk"] + assert pipeline._task_state.is_streaming_response is True + assert pipeline._task_state.first_token_time is not None + assert pipeline._task_state.last_token_time is not None + assert pipeline._task_state.answer == "hi" + assert published == [queue_message] + + def test_handle_output_moderation_chunk_appends_token(self): + 
pipeline = _make_pipeline() + seen: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + seen.append(text) + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is False + assert seen == ["token"] + + def test_handle_retriever_and_annotation_events(self): + pipeline = _make_pipeline() + calls = {"retriever": 0, "annotation": 0} + + def _hit_retriever(event): + calls["retriever"] += 1 + + def _hit_annotation(event): + calls["annotation"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever + pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation + + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann") + + assert list(pipeline._handle_retriever_resources_event(retriever_event)) == [] + assert list(pipeline._handle_annotation_reply_event(annotation_event)) == [] + assert calls == {"retriever": 1, "annotation": 1} + + def test_handle_message_replace_event(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + + event = QueueMessageReplaceEvent( + text="new", + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + ) + + assert list(pipeline._handle_message_replace_event(event)) == ["replace"] + + def test_handle_human_input_events(self): + pipeline = _make_pipeline() + persisted: list[str] = [] + pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved") + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=NodeType.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert persisted == ["saved"] + + def test_save_message_strips_markdown_and_sets_usage(self): + pipeline = _make_pipeline() + pipeline._recorded_files = [ + { + "type": "image", + "transfer_method": "remote", + "remote_url": "http://example.com/file.png", + "related_id": "file-id", + } + ] + pipeline._task_state.answer = "![img](url) hello" + pipeline._task_state.is_streaming_response = True + pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1 + pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2 + + message = SimpleNamespace( + id="message-id", + status=MessageStatus.PAUSED, + answer="", + updated_at=None, + provider_response_latency=None, + message_tokens=None, + message_unit_price=None, + message_price_unit=None, + answer_tokens=None, + answer_unit_price=None, + answer_price_unit=None, + total_price=None, + currency=None, + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="end-user", + ) + + class 
_Session: + def scalar(self, *args, **kwargs): + return message + + def add_all(self, items): + self.items = items + + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state) + + assert message.status == MessageStatus.NORMAL + assert message.answer == "hello" + assert message.message_metadata + + def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._message_end_to_stream_response = lambda: "end" + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION))) + + assert responses == ["end"] + assert saved == ["saved"] + + def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline._message_end_to_stream_response = lambda: "end" + + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent())) + + assert responses == ["replace", "end"] + assert saved == ["saved"] + + def test_dispatch_event_handles_node_exception(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda *args, **kwargs: None + + event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["failed"] diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py new file mode 100644 index 0000000000..a871e8d93b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py @@ -0,0 +1,302 @@ +import uuid +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.agent_chat.app_config_manager import ( + AgentChatAppConfigManager, +) +from core.entities.agent_entities import PlanningStrategy + + +class TestAgentChatAppConfigManagerGetAppConfig: + def test_get_app_config_override_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"ignored": True} + + 
override_config = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override_config, + ) + + assert result.app_model_config_dict == override_config + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.variables == "variables" + assert result.external_data_variables == "external" + + def test_get_app_config_conversation_specific(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + conversation = mocker.MagicMock() + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=None, + ) + + assert result.app_model_config_dict == app_model_config.to_dict.return_value + assert result.app_model_config_from.value == "conversation-specific-config" + + def test_get_app_config_latest_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, 
"convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=None, + ) + + assert result.app_model_config_from.value == "app-latest-config" + + +class TestAgentChatAppConfigManagerConfigValidate: + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {}, + "user_input_form": {}, + "file_upload": {}, + "prompt_template": {}, + "agent_mode": {}, + "opening_statement": {}, + "suggested_questions_after_answer": {}, + "speech_to_text": {}, + "text_to_speech": {}, + "retriever_resource": {}, + "dataset": {}, + "moderation": {}, + "extra": "value", + } + + def return_with_key(key): + return config, [key] + + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("model"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("file_upload"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=lambda app_mode, cfg: return_with_key("prompt_template"), + ) + mocker.patch.object( + AgentChatAppConfigManager, + "validate_agent_mode_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("opening_statement"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("speech_to_text"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("text_to_speech"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("retriever_resource"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("moderation"), + ) + + filtered = AgentChatAppConfigManager.config_validate("tenant", config) + assert set(filtered.keys()) == { + "model", + "user_input_form", + "file_upload", + "prompt_template", + "agent_mode", + "opening_statement", + 
"suggested_questions_after_answer", + "speech_to_text", + "text_to_speech", + "retriever_resource", + "dataset", + "moderation", + } + assert "extra" not in filtered + + +class TestValidateAgentModeAndSetDefaults: + def test_defaults_when_missing(self): + config = {} + updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert "agent_mode" in updated + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + assert keys == ["agent_mode"] + + @pytest.mark.parametrize( + "agent_mode", + ["invalid", 123], + ) + def test_agent_mode_type_validation(self, agent_mode): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode}) + + def test_agent_mode_empty_list_defaults(self): + config = {"agent_mode": []} + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + + def test_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}}) + + def test_strategy_must_be_valid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}} + ) + + def test_tools_must_be_list(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}} + ) + + def test_old_tool_dataset_requires_id(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}} + ) + + def test_old_tool_dataset_id_must_be_uuid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}}, + ) + + def test_old_tool_dataset_id_not_exists(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=False, + ) + dataset_id = str(uuid.uuid4()) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}}, + ) + + def test_old_tool_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}}, + ) + + @pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"]) + def test_new_style_tool_requires_fields(self, missing_key): + tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"} + tool.pop(missing_key, None) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [tool]}} + ) + + def test_valid_old_and_new_style_tools(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=True, + ) + dataset_id 
= str(uuid.uuid4()) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER.value, + "tools": [ + {"dataset": {"id": dataset_id}}, + { + "provider_type": "builtin", + "provider_id": "p1", + "tool_name": "tool", + "tool_parameters": {}, + }, + ], + } + } + + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False + assert updated["agent_mode"]["tools"][1]["enabled"] is False diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py new file mode 100644 index 0000000000..53f26d1592 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -0,0 +1,296 @@ +import contextlib + +import pytest +from pydantic import ValidationError + +from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +class DummyAccount: + def __init__(self, user_id): + self.id = user_id + self.session_id = f"session-{user_id}" + + +@pytest.fixture +def generator(mocker): + gen = AgentChatAppGenerator() + mocker.patch( + "core.app.apps.agent_chat.app_generator.current_app", + new=mocker.MagicMock(_get_current_object=mocker.MagicMock()), + ) + mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx") + return gen + + +class TestAgentChatAppGeneratorGenerate: + def test_generate_rejects_blocking_mode(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False) + + def test_generate_requires_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock()) + + def test_generate_rejects_non_string_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": 123, "inputs": {}}, + invoke_from=mocker.MagicMock(), + ) + + def test_generate_override_requires_debugger(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}}, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_success_with_debugger_override(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + invoke_from = InvokeFrom.DEBUGGER + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), 
mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate", + return_value={"validated": True}, + ) + app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=app_config, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ConversationService.get_conversation", + return_value=mocker.MagicMock(id="conv"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + + queue_manager = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=queue_manager, + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = { + "query": "hello", + "inputs": {"name": "world"}, + "conversation_id": "conv", + "model_config": {"model": {"provider": "p"}}, + "files": [{"id": "f1"}], + } + + result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True) + + assert result == {"result": "ok"} + thread_obj.start.assert_called_once() + + def test_generate_without_file_config(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=None, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + 
+ mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=mocker.MagicMock(), + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = {"query": "hello", "inputs": {"name": "world"}} + + result = generator.generate( + app_model=app_model, + user=user, + args=args, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == {"result": "ok"} + + +class TestAgentChatAppGeneratorWorker: + @pytest.fixture(autouse=True) + def patch_context(self, mocker): + @contextlib.contextmanager + def ctx_manager(*args, **kwargs): + yield + + mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager) + + def test_generate_worker_handles_generate_task_stopped(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = GenerateTaskStoppedError() + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + queue_manager.publish_error.assert_not_called() + + @pytest.mark.parametrize( + "error", + [ + InvokeAuthorizationError("bad"), + ValidationError.from_exception_data("TestModel", []), + ValueError("bad"), + Exception("bad"), + ], + ) + def test_generate_worker_publishes_errors(self, generator, mocker, error): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = error + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + assert queue_manager.publish_error.called + + def test_generate_worker_logs_value_error_when_debug(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = ValueError("bad") + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", 
new=mocker.MagicMock(DEBUG=True)) + logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py new file mode 100644 index 0000000000..5603115b30 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -0,0 +1,413 @@ +import pytest + +from core.agent.entities import AgentEntity +from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey + + +@pytest.fixture +def runner(): + return AgentChatAppRunner() + + +class TestAgentChatAppRunnerRun: + def test_run_app_not_found(self, runner, mocker): + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock()) + generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_moderation_error_direct_output(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad")) + mocker.patch.object(runner, "direct_output") + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + runner.direct_output.assert_called_once() + + def test_run_annotation_reply_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + user_id="user", + invoke_from=mocker.MagicMock(), + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + annotation = mocker.MagicMock(id="anno", content="answer") + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation) + mocker.patch.object(runner, "direct_output") + + queue_manager = mocker.MagicMock() + 
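+ # An annotation hit should be published to the queue and answered via direct_output, short-circuiting the agent run.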
runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock()) + + queue_manager.publish.assert_called_once() + runner.direct_output.assert_called_once() + + def test_run_hosting_moderation_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=True) + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_model_schema_missing(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = None + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + @pytest.mark.parametrize( + ("mode", "expected_runner"), + [ + (LLMMode.CHAT, "CotChatAgentRunner"), + (LLMMode.COMPLETION, "CotCompletionAgentRunner"), + ], + ) + def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + 
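+ # get_model_schema is stubbed to return None below, so run() is expected to raise ValueError.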
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: mode} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + runner_instance.run.assert_called_once() + runner._handle_invoke_result.assert_called_once() + + def test_run_invalid_llm_mode_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + def test_run_function_calling_strategy_selected_by_features(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", 
strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [ModelFeature.TOOL_CALL] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING + runner_instance.run.assert_called_once() + + def test_run_conversation_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_message_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, 
+ inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, mocker.MagicMock(id="conv"), None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_invalid_agent_strategy_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m") + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py new file mode 100644 index 0000000000..02a1e04c98 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py @@ -0,0 +1,162 @@ +from collections.abc import Generator + +from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestAgentChatAppGenerateResponseConverterBlocking: + def test_convert_blocking_full_response(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + 
data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={"a": 1}, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["answer"] == "answer" + assert result["metadata"] == {"a": 1} + + def test_convert_blocking_simple_response_with_dict_metadata(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={ + "retriever_resources": [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + + def test_convert_blocking_simple_response_with_non_dict_metadata(self): + blocking = ChatbotAppBlockingResponse.model_construct( + task_id="task", + data=ChatbotAppBlockingResponse.Data.model_construct( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + +class TestAgentChatAppGenerateResponseConverterStream: + def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]: + def _gen(): + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=1, + stream_response=PingStreamResponse(task_id="t"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=2, + stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=3, + stream_response=MessageEndStreamResponse( + task_id="t", + id="m1", + metadata={ + "retriever_resources": [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + "summary": "summary", + "extra": "ignored", + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + ), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=4, + stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")), + ) + + return _gen() + + def test_convert_stream_full_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream())) + assert items[0] == "ping" + assert items[1]["event"] == "message" + assert "answer" in items[1] + assert items[2]["event"] == "message_end" + assert items[3]["event"] == "error" + + def test_convert_stream_simple_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream())) + assert items[0] == "ping" + # The message event at items[1] should carry the streamed answer verbatim + assert items[1]["event"] == "message" + assert items[1]["answer"] == "hi" + assert items[2]["event"] == "message_end" + assert "metadata" in items[2] + metadata = items[2]["metadata"] + assert "annotation_reply" not in
metadata + assert "usage" not in metadata + assert metadata["retriever_resources"] == [ + { + "segment_id": "s1", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "content", + "summary": "summary", + } + ] + assert items[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/chat/__init__.py b/api/tests/unit_tests/core/app/apps/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py new file mode 100644 index 0000000000..271d007be6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py @@ -0,0 +1,113 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from models.model import AppMode + + +class TestChatAppConfigManager: + def test_get_app_config_uses_override_dict(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"}) + override = {"model": "override"} + + model_entity = ModelConfigEntity(provider="p", model="m") + prompt_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ) + + with ( + patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])), + ): + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override, + ) + + assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert app_config.app_model_config_dict == override + assert app_config.app_mode == AppMode.CHAT + + def test_config_validate_filters_related_keys(self): + config = {"extra": 1} + + def _add_key(key, value): + def _inner(*args, **kwargs): + config = args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=_add_key("model", 1), + ), + patch( + "core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=_add_key("inputs", 2), + ), + patch( + "core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 3), + ), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=_add_key("prompt", 4), + ), + patch( + "core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=_add_key("dataset", 5), + ), + patch( + "core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + 
side_effect=_add_key("opening_statement", 6), + ), + patch( + "core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 7), + ), + patch( + "core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 8), + ), + patch( + "core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 9), + ), + patch( + "core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 10), + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 11), + ), + ): + filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config) + + assert filtered["model"] == 1 + assert filtered["inputs"] == 2 + assert filtered["file_upload"] == 3 + assert filtered["prompt"] == 4 + assert filtered["dataset"] == 5 + assert filtered["opening_statement"] == 6 + assert filtered["suggested_questions_after_answer"] == 7 + assert filtered["speech_to_text"] == 8 + assert filtered["text_to_speech"] == 9 + assert filtered["retriever_resource"] == 10 + assert filtered["sensitive_word_avoidance"] == 11 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py new file mode 100644 index 0000000000..3cdffbb4cd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -0,0 +1,280 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.chat.app_generator import ChatAppGenerator +from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAnnotationReplyEvent +from core.moderation.base import ModerationError +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from models.model import AppMode + + +class DummyGenerateEntity: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class DummyQueueManager: + def __init__(self, *args, **kwargs): + self.published = [] + + def publish_error(self, error, pub_from): + self.published.append((error, pub_from)) + + def publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestChatAppGenerator: + def test_generate_requires_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_rejects_non_string_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"query": 1, "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_debugger_overrides_model_config(self): + generator = ChatAppGenerator() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + user = SimpleNamespace(id="user-1", session_id="session-1") + args = {"query": "hi", "inputs": {}, "model_config": 
{"foo": "bar"}} + + with ( + patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None), + patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}), + patch( + "core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config", + return_value=SimpleNamespace( + variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT + ), + ), + patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]), + patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity), + patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager), + patch( + "core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True} + ), + patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})), + patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}), + patch.object( + ChatAppGenerator, + "_init_generate_records", + return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")), + ), + patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}), + patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f), + patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread, + ): + mock_thread.return_value.start.return_value = None + result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False) + + assert result == {"ok": True} + + def test_generate_rejects_model_config_override_for_non_debugger(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + with ( + patch.object( + ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {}) + ), + ): + generator.generate( + app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value), + user=SimpleNamespace(id="u1", session_id="s1"), + args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_worker_handles_exceptions(self): + generator = ChatAppGenerator() + queue_manager = DummyQueueManager() + entity = DummyGenerateEntity(task_id="t1", user_id="u1") + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + assert queue_manager.published + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", 
side_effect=GenerateTaskStoppedError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + +class TestChatAppRunner: + def test_run_raises_when_app_missing(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[] + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None): + with pytest.raises(ValueError): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + def test_run_moderation_error_direct_output(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + mock_direct.assert_called_once() + + def test_run_annotation_reply_short_circuits(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + annotation = SimpleNamespace(id="ann-1", content="answer") + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + queue_manager = DummyQueueManager() + runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1")) + + 
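+ # The matched annotation must surface on the queue as a QueueAnnotationReplyEvent before the stored answer is emitted.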
assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published) + mock_direct.assert_called_once() + + def test_run_returns_when_hosting_moderation_blocks(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None), + patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True), + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) diff --git a/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py new file mode 100644 index 0000000000..01272ba052 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py @@ -0,0 +1,65 @@ +from collections.abc import Generator + +from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestChatAppGenerateResponseConverter: + def test_convert_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + + response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "usage" not in response["metadata"] + + def test_convert_stream_responses(self): + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream())) + assert full[0] == "ping" + assert full[1]["event"] == "message" + assert full[2]["event"] == "error" + + simple = 
list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert simple[0] == "ping" + assert simple[-1]["event"] == "message_end" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py new file mode 100644 index 0000000000..51f33bac35 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -0,0 +1,162 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.completion.app_runner as module +from core.app.apps.completion.app_runner import CompletionAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + +@pytest.fixture +def runner(): + return CompletionAppRunner() + + +def _build_app_config(dataset=None, external_tools=None, additional_features=None): + app_config = MagicMock() + app_config.app_id = "app1" + app_config.tenant_id = "tenant" + app_config.prompt_template = MagicMock() + app_config.dataset = dataset + app_config.external_data_variables = external_tools or [] + app_config.additional_features = additional_features + app_config.app_model_config_dict = {"file_upload": {"enabled": True}} + return app_config + + +def _build_generate_entity(app_config, file_upload_config=None): + model_conf = MagicMock( + provider_model_bundle="bundle", + model="model", + parameters={"max_tokens": 10}, + stop=["stop"], + ) + return SimpleNamespace( + app_config=app_config, + model_conf=model_conf, + inputs={"qvar": "query_from_input"}, + query="original_query", + files=[], + file_upload_config=file_upload_config, + stream=True, + user_id="user", + invoke_from=MagicMock(), + ) + + +class TestCompletionAppRunner: + def test_run_app_not_found(self, runner, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + with pytest.raises(ValueError): + runner.run(app_generate_entity, MagicMock(), MagicMock()) + + def test_run_moderation_error_outputs_direct(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked")) + runner.direct_output = MagicMock() + runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner.direct_output.assert_called_once() + runner._handle_invoke_result.assert_not_called() + + def test_run_hosting_moderation_stops(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + 
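+ # check_hosting_moderation returning True should stop the run before any invoke result is handled.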
runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner._handle_invoke_result.assert_not_called() + + def test_run_dataset_and_external_tools_flow(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + session.close = MagicMock() + mocker.patch.object(module.db, "session", session) + + retrieve_config = MagicMock(query_variable="qvar") + dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config) + additional_features = MagicMock(show_retrieve_source=True) + app_config = _build_app_config( + dataset=dataset_config, + external_tools=["tool"], + additional_features=additional_features, + ) + + file_upload_config = MagicMock() + file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH + + app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config) + + runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])]) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs) + runner.check_hosting_moderation = MagicMock(return_value=False) + runner.recalc_llm_max_tokens = MagicMock() + runner._handle_invoke_result = MagicMock() + + dataset_retrieval = MagicMock() + dataset_retrieval.retrieve.return_value = ("ctx", ["file1"]) + mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval) + + model_instance = MagicMock() + model_instance.invoke_llm.return_value = "invoke_result" + mocker.patch.object(module, "ModelInstance", return_value=model_instance) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant")) + + dataset_retrieval.retrieve.assert_called_once() + assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input" + runner._handle_invoke_result.assert_called_once() + + def test_run_uses_low_image_detail_default(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config, file_upload_config=None) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + assert ( + runner.organize_prompt_messages.call_args.kwargs["image_detail_config"] + == ImagePromptMessageContent.DETAIL.LOW + ) diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py new file mode 100644 index 0000000000..024bd8f302 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.completion.app_config_manager as module +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.completion.app_config_manager import 
CompletionAppConfigManager +from models.model import AppMode + + +class TestCompletionAppConfigManager: + def test_get_app_config_with_override(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + override_config = {"model": {"provider": "override"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_config, + ) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.app_model_config_dict == override_config + assert result.variables == ["v1"] + assert result.external_data_variables == ["ext1"] + assert result.app_mode == AppMode.COMPLETION + + def test_get_app_config_without_override_uses_model_config(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], [])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + assert result.app_model_config_dict == {"model": {"provider": "x"}} + + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {"provider": "x"}, + "variables": ["v"], + "file_upload": {"enabled": True}, + "prompt": {"template": "t"}, + "dataset": {"enabled": True}, + "tts": {"enabled": True}, + "more_like_this": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.ModelConfigManager, + "validate_and_set_defaults", + return_value=(config, ["model"]), + ) + mocker.patch.object( + module.BasicVariablesConfigManager, + "validate_and_set_defaults", + return_value=(config, ["variables"]), + ) + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.PromptTemplateConfigManager, + 
"validate_and_set_defaults", + return_value=(config, ["prompt"]), + ) + mocker.patch.object( + module.DatasetConfigManager, + "validate_and_set_defaults", + return_value=(config, ["dataset"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.MoreLikeThisConfigManager, + "validate_and_set_defaults", + return_value=(config, ["more_like_this"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = CompletionAppConfigManager.config_validate("tenant", config) + + assert "extra" not in filtered + assert set(filtered.keys()) == { + "model", + "variables", + "file_upload", + "prompt", + "dataset", + "tts", + "more_like_this", + "moderation", + } diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py new file mode 100644 index 0000000000..2714757353 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -0,0 +1,321 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +import core.app.apps.completion.app_generator as module +from core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +@pytest.fixture +def generator(mocker): + gen = CompletionAppGenerator() + + mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn) + + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app))) + + thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=thread) + + mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + return gen + + +def _build_app_model(): + return MagicMock(tenant_id="tenant", id="app1", mode="completion") + + +def _build_user(): + return MagicMock(id="user", session_id="session") + + +def _build_app_model_config(): + config = MagicMock(id="cfg") + config.to_dict.return_value = {"model": {"provider": "x"}} + return config + + +class TestCompletionAppGenerator: + def test_generate_invalid_query_type(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": 123, "inputs": {}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + def test_generate_override_not_debugger(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": {}}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + def 
test_generate_success_no_file_config(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.file_factory, "build_from_mappings") + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_not_called() + + def test_generate_success_with_files(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_called_once() + + def test_generate_override_model_config_debugger(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + override_config = {"model": {"provider": "override"}} + mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.ModelConfigConverter, 
"convert", return_value=MagicMock()) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": override_config}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert get_app_config.call_args.kwargs["override_config_dict"] == override_config + + def test_generate_more_like_this_message_not_found(self, generator, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MessageNotExistsError): + generator.generate_more_like_this( + app_model=_build_app_model(), + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_disabled(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False}) + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_app_model_config_missing(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = None + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_message_config_none(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock(app_model_config=None) + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_success(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock() + message.message_files = [{"id": "f"}] + message.inputs = {"a": 1} + message.query = "q" + + app_model_config = MagicMock() + app_model_config.to_dict.return_value = { + "model": {"completion_params": {"temperature": 0.1}}, + "file_upload": {"enabled": True}, + } + message.app_model_config = app_model_config + + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", 
return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + stream=True, + ) + + assert result == "converted" + override_dict = get_app_config.call_args.kwargs["override_config_dict"] + assert override_dict["model"]["completion_params"]["temperature"] == 0.9 + + @pytest.mark.parametrize( + ("error", "should_publish"), + [ + (GenerateTaskStoppedError(), False), + (InvokeAuthorizationError("bad"), True), + ( + ValidationError.from_exception_data( + "Model", + [{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}], + ), + True, + ), + (ValueError("bad"), True), + (RuntimeError("boom"), True), + ], + ) + def test_generate_worker_error_handling(self, generator, mocker, error, should_publish): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(generator, "_get_message", return_value=MagicMock()) + + runner_instance = MagicMock() + runner_instance.run.side_effect = error + mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=MagicMock(), + queue_manager=queue_manager, + message_id="msg", + ) + + assert queue_manager.publish_error.called is should_publish diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py new file mode 100644 index 0000000000..cf473dfbeb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py @@ -0,0 +1,153 @@ +from collections.abc import Generator + +from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + CompletionAppBlockingResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestCompletionAppGenerateResponseConverter: + def test_convert_blocking_full_response(self): + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata={"k": "v"}, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["task_id"] == "task" + assert result["message_id"] == "msg" + assert result["answer"] == "answer" 
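+ # the full converter forwards metadata untouched; the "simple" variant tested below trims it down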
+ assert result["metadata"] == {"k": "v"} + + def test_convert_blocking_simple_response_metadata_simplified(self): + metadata = { + "retriever_resources": [ + { + "segment_id": "s", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "c", + "summary": "sum", + "extra": "x", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + } + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata=metadata, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s" + assert "extra" not in result["metadata"]["retriever_resources"][0] + + def test_convert_blocking_simple_response_metadata_not_dict(self): + data = CompletionAppBlockingResponse.Data.model_construct( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ) + blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + def test_convert_stream_full_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + assert result[2]["event"] == "message" + + def test_convert_stream_simple_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=MessageEndStreamResponse( + task_id="t", + id="end", + metadata={ + "retriever_resources": [ + { + "segment_id": "s", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "c", + "summary": "sum", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + }, + ), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "message_end" + assert "annotation_reply" not in result[1]["metadata"] + assert "usage" not in result[1]["metadata"] + assert result[2]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py new file mode 100644 index 0000000000..5d4c9bcde0 --- /dev/null +++ 
b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py @@ -0,0 +1,55 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.pipeline.pipeline_config_manager as module +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from models.model import AppMode + + +def test_get_pipeline_config(mocker): + pipeline = MagicMock(tenant_id="tenant", id="pipe1") + workflow = MagicMock(id="wf1") + + mocker.patch.object( + module.WorkflowVariablesConfigManager, + "convert_rag_pipeline_variable", + return_value=["var1"], + ) + mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start") + + assert result.tenant_id == "tenant" + assert result.app_id == "pipe1" + assert result.workflow_id == "wf1" + assert result.app_mode == AppMode.RAG_PIPELINE + assert result.rag_pipeline_variables == ["var1"] + + +def test_config_validate_filters_related_keys(mocker): + config = { + "file_upload": {"enabled": True}, + "tts": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = PipelineConfigManager.config_validate("tenant", config) + + assert set(filtered.keys()) == {"file_upload", "tts", "moderation"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py new file mode 100644 index 0000000000..94ed8166b9 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -0,0 +1,111 @@ +from collections.abc import Generator + +from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +def test_convert_blocking_full_and_simple_response(): + blocking = WorkflowAppBlockingResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowAppBlockingResponse.Data( + id="id", + workflow_id="wf", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"k": "v"}, + error=None, + elapsed_time=0.1, + total_tokens=10, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert full == simple + assert full["workflow_run_id"] == "run" + assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED + + +def test_convert_stream_full_response(): + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + workflow_run_id="run", + ) + yield 
WorkflowAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + workflow_run_id="run", + ) + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + + +def test_convert_stream_simple_response_node_ignore_details(): + node_start = NodeStartStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + process_data=None, + process_data_truncated=False, + outputs={"b": 2}, + outputs_truncated=False, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=0.1, + execution_metadata=None, + created_at=1, + finished_at=2, + files=[], + ), + ) + + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run") + yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run") + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0]["event"] == "node_started" + assert result[0]["data"]["inputs"] is None + assert result[1]["event"] == "node_finished" + assert result[1]["data"]["inputs"] is None diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py new file mode 100644 index 0000000000..06face41fe --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -0,0 +1,699 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock + +import pytest + +import core.app.apps.pipeline.pipeline_generator as module +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceProviderType + + +class FakeRagPipelineGenerateEntity(SimpleNamespace): + class SingleIterationRunEntity(SimpleNamespace): + pass + + class SingleLoopRunEntity(SimpleNamespace): + pass + + def model_dump(self): + return dict(self.__dict__) + + +@pytest.fixture +def generator(mocker): + gen = module.PipelineGenerator() + + mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity) + mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs) + mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock())) + mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock())) + + return gen + + +def _build_pipeline_dataset(): + return SimpleNamespace( + id="ds", + name="dataset", + description="desc", + chunk_structure="chunk", + built_in_field_enabled=True, + tenant_id="tenant", + ) + + +def _build_pipeline(): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + pipeline.retrieve_dataset.return_value = _build_pipeline_dataset() + return pipeline + + +def _build_workflow(): 
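+ """Workflow stand-in exposing just the attributes these generator tests read."""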
+ return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant") + + +def _build_user(): + return MagicMock(id="user", name="User", session_id="session") + + +def _build_args(): + return { + "inputs": {"k": "v"}, + "start_node_id": "start", + "datasource_type": DatasourceProviderType.LOCAL_FILE.value, + "datasource_info_list": [{"name": "file"}], + } + + +def _patch_session(mocker, session): + mocker.patch.object(module, "Session", return_value=session) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + +def _dummy_preserve(*args, **kwargs): + return contextlib.nullcontext() + + +class DummySession: + def __init__(self): + self.scalar = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +def test_generate_dataset_missing(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.generate( + pipeline=pipeline, + workflow=_build_workflow(), + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + +def test_generate_debugger_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + datasource_info_list = [{"name": "file1"}, {"name": "file2"}] + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=datasource_info_list, + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1) + + document1 = SimpleNamespace( + id="doc1", + position=1, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file1", + indexing_status="", + error=None, + enabled=True, + ) + document2 = SimpleNamespace( + id="doc2", + position=2, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file2", + indexing_status="", + error=None, + enabled=True, + ) + mocker.patch.object(generator, 
"_build_document", side_effect=[document1, document2]) + + mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock()) + + db_session = MagicMock() + mocker.patch.object(module.db, "session", db_session) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + task_proxy = MagicMock() + mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=False, + ) + + assert result["batch"] + assert len(result["documents"]) == 2 + task_proxy.delay.assert_called_once() + + +def test_generate_is_retry_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=True, + is_retry=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_worker_handles_errors(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + runner_instance.run.side_effect = ValueError("bad") + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + queue_manager.publish_error.assert_called_once() + + +def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + 
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session" + + +def test_generate_raises_when_workflow_not_found(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + +def test_generate_success_returns_converted(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", session) + + queue_manager = MagicMock() + mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager) + + worker_thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=worker_thread) + + mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock()) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + assert result == "converted" + + +def test_single_iteration_generate_validates_inputs(generator, mocker): + with pytest.raises(ValueError): + generator.single_iteration_generate(_build_pipeline(), 
_build_workflow(), "", _build_user(), {}) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + _build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None} + ) + + +def test_single_iteration_generate_dataset_required(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + ) + + +def test_single_iteration_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_single_loop_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_loop_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + app_entity = FakeRagPipelineGenerateEntity(task_id="t") + + task_pipeline = MagicMock() + task_pipeline.process.side_effect = ValueError("I/O operation on closed file.") + mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=app_entity, + workflow=workflow, + queue_manager=MagicMock(), + user=_build_user(), + draft_var_saver_factory=MagicMock(), + stream=False, + ) + + +def test_build_document_sets_metadata_for_builtin_fields(generator, mocker): + class 
DummyDocument(SimpleNamespace): + pass + + mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs)) + + document = generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=True, + datasource_type=DatasourceProviderType.LOCAL_FILE, + datasource_info={"name": "file"}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + assert document.name == "file" + assert document.doc_metadata + + +def test_build_document_invalid_datasource_type(generator): + with pytest.raises(ValueError): + generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=False, + datasource_type="invalid", + datasource_info={}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + +def test_format_datasource_info_list_non_online_drive(generator): + result = generator._format_datasource_info_list( + DatasourceProviderType.LOCAL_FILE, + [{"name": "file"}], + _build_pipeline(), + _build_workflow(), + "start", + _build_user(), + ) + + assert result == [{"name": "file"}] + + +def test_format_datasource_info_list_missing_node_data(generator): + workflow = MagicMock(graph_dict={"nodes": []}) + + with pytest.raises(ValueError): + generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + +def test_format_datasource_info_list_online_drive_folder(generator, mocker): + workflow = MagicMock( + graph_dict={ + "nodes": [ + { + "id": "start", + "data": { + "plugin_id": "p", + "provider_name": "provider", + "datasource_name": "drive", + "credential_id": "cred", + }, + } + ] + } + ) + + runtime = MagicMock() + runtime.runtime = SimpleNamespace(credentials=None) + runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE + + mocker.patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=runtime, + ) + mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"}) + + mocker.patch.object( + generator, + "_get_files_in_folder", + side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}), + ) + + result = generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + assert result == [{"id": "f"}] + + +def test_get_files_in_folder_recurses_and_collects(generator): + class File: + def __init__(self, id, name, type): + self.id = id + self.name = name + self.type = type + + class FilesPage: + def __init__(self, files, is_truncated=False, next_page_parameters=None): + self.files = files + self.is_truncated = is_truncated + self.next_page_parameters = next_page_parameters + + class Result: + def __init__(self, result): + self.result = result + + class Runtime: + def __init__(self): + self.calls = [] + + def datasource_provider_type(self): + return DatasourceProviderType.ONLINE_DRIVE + + def online_drive_browse_files(self, user_id, request, provider_type): + self.calls.append(request.next_page_parameters) + if request.prefix == "fd": + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + if request.next_page_parameters is None: + return iter( + [ + Result( + [FilesPage([File("f1", "file", "file"), File("fd", "folder", 
"folder")], True, {"page": 2})] + ) + ] + ) + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + + runtime = Runtime() + all_files = [] + + generator._get_files_in_folder( + datasource_runtime=runtime, + prefix="root", + bucket="b", + user_id="user", + all_files=all_files, + datasource_info={}, + ) + + assert {f["id"] for f in all_files} == {"f1", "f2"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py new file mode 100644 index 0000000000..72f7552bd1 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -0,0 +1,57 @@ +import pytest + +import core.app.apps.pipeline.pipeline_queue_manager as module +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult + + +def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER) + + manager.stop_listen.assert_called_once() + + +def test_publish_stop_events_trigger_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + for event in [ + QueueErrorEvent(error=ValueError("bad")), + QueueMessageEndEvent(llm_result=LLMResult.model_construct()), + QueueWorkflowSucceededEvent(), + QueueWorkflowFailedEvent(error="failed", exceptions_count=1), + QueueWorkflowPartialSuccessEvent(exceptions_count=1), + ]: + manager.stop_listen.reset_mock() + manager._publish(event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_called_once() + + +def test_publish_non_stop_event_no_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent) + manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py new file mode 100644 index 0000000000..eec95b7f39 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -0,0 +1,297 @@ +""" +Unit tests for PipelineRunner behavior. +Asserts correct event handling, error propagation, and user invocation logic. 
+Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies. +Covers app/workflow lookup, graph-initialization guards, InvokeFrom-to-UserFrom resolution, and failed-run event handling driven by GraphRunFailedEvent. +Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.pipeline.pipeline_runner as module +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.graph_events import GraphRunFailedEvent + + +def _build_app_generate_entity() -> SimpleNamespace: + app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant") + return SimpleNamespace( + app_config=app_config, + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + trace_manager=MagicMock(), + inputs={"input1": "v1"}, + files=[], + workflow_execution_id="run", + document_id="doc", + original_document_id=None, + batch="batch", + dataset_id="ds", + datasource_type="local_file", + datasource_info={"name": "file"}, + start_node_id="start", + call_depth=0, + single_iteration_run=None, + single_loop_run=None, + ) + + +@pytest.fixture +def runner(): + app_generate_entity = _build_app_generate_entity() + queue_manager = MagicMock() + variable_loader = MagicMock() + workflow = MagicMock() + workflow_execution_repository = MagicMock() + workflow_node_execution_repository = MagicMock() + + return PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=queue_manager, + variable_loader=variable_loader, + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + +def test_get_app_id(runner): + assert runner._get_app_id() == "pipe" + + +def test_get_workflow_returns_workflow(mocker, runner): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + workflow = MagicMock(id="wf") + + query = MagicMock() + query.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query))) + + result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") + + assert result == workflow + + +def test_init_rag_pipeline_graph_invalid_config(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={}) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": "bad", "edges": []} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": [], "edges": "bad"} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def 
test_init_rag_pipeline_graph_not_found(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []}) + mocker.patch.object(module.Graph, "init", return_value=None) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def test_update_document_status_on_failure(mocker, runner): + document = MagicMock() + + query = MagicMock() + query.where.return_value.first.return_value = document + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + event = GraphRunFailedEvent(error="boom") + + runner._update_document_status(event, document_id="doc", dataset_id="ds") + + assert document.indexing_status == "error" + assert document.error == "boom" + session.commit.assert_called_once() + + +def test_run_pipeline_not_found(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.invoke_from = InvokeFrom.WEB_APP + app_generate_entity.single_iteration_run = None + app_generate_entity.single_loop_run = None + + query = MagicMock() + query.where.return_value.first.return_value = None + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_workflow_not_initialized(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + session = MagicMock() + session.query.return_value = query_pipeline + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + runner.get_workflow = MagicMock(return_value=None) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_single_iteration_path(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.single_iteration_run = MagicMock() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock( + return_value=MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={}, + type="rag-pipeline", + version="v1", + ) + ) + runner._prepare_single_node_execution = 
MagicMock(return_value=("graph", "pool", "state")) + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [MagicMock()] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._prepare_single_node_execution.assert_called_once() + runner._handle_event.assert_called() + + +def test_run_normal_path_builds_graph(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + workflow = MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={"nodes": [], "edges": []}, + environment_variables=[], + rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}], + type="rag-pipeline", + version="v1", + ) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock(return_value=workflow) + runner._init_rag_pipeline_graph = MagicMock(return_value="graph") + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + mocker.patch.object( + module.RAGPipelineVariable, + "model_validate", + return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), + ) + mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._init_rag_pipeline_graph.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 43a97ae098..8f1baaa1e4 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest from core.app.apps.base_app_generator import BaseAppGenerator @@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default(): ) assert result is None + + +class TestBaseAppGeneratorExtras: + def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch): + base_app_generator = BaseAppGenerator() + + variables = [ + VariableEntity( + variable="file", + label="file", + type=VariableEntityType.FILE, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( 
+ variable="file_list", + label="file_list", + type=VariableEntityType.FILE_LIST, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( + variable="json", + label="json", + type=VariableEntityType.JSON_OBJECT, + required=False, + ), + ] + + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mapping", + lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + ) + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mappings", + lambda mappings, tenant_id, config: ["file-1", "file-2"], + ) + + user_inputs = { + "file": {"id": "file-id"}, + "file_list": [{"id": "file-1"}, {"id": "file-2"}], + "json": {"key": "value"}, + } + + prepared = base_app_generator._prepare_user_inputs( + user_inputs=user_inputs, + variables=variables, + tenant_id="tenant-id", + ) + + assert prepared["file"] == "file-object" + assert prepared["file_list"] == ["file-1", "file-2"] + assert prepared["json"] == {"key": "value"} + + def test_prepare_user_inputs_rejects_invalid_dict_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": {"unexpected": "dict"}}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_prepare_user_inputs_rejects_invalid_list_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": [{"unexpected": "dict"}]}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_convert_to_event_stream(self): + base_app_generator = BaseAppGenerator() + + assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True} + + def _gen(): + yield {"delta": "hi"} + yield "ping" + + converted = list(base_app_generator.convert_to_event_stream(_gen())) + + assert converted[0].startswith("data: ") + assert "\n\n" in converted[0] + assert converted[1] == "event: ping\n\n" + + def test_get_draft_var_saver_factory_debugger(self): + from core.app.entities.app_invoke_entities import InvokeFrom + from dify_graph.enums import NodeType + from models import Account + + base_app_generator = BaseAppGenerator() + account = Account(name="Tester", email="tester@example.com") + account.id = "account-id" + account.tenant_id = "tenant-id" + + factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) + saver = factory( + session=MagicMock(), + app_id="app-id", + node_id="node-id", + node_type=NodeType.START, + node_execution_id="node-exec-id", + ) + + assert saver is not None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py new file mode 100644 index 0000000000..c6dc20ffc6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from 
core.app.entities.queue_entities import QueueErrorEvent + + +class DummyQueueManager(AppQueueManager): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.published = [] + + def _publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestBaseAppQueueManager: + def test_init_requires_user_id(self): + with pytest.raises(ValueError): + DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API) + + def test_publish_error_records_event(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE) + + assert isinstance(manager.published[0][0], QueueErrorEvent) + + def test_set_stop_flag_checks_user(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.get.return_value = b"end-user-u1" + AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1") + + mock_redis.setex.assert_called_once() + + def test_set_stop_flag_no_user_check(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + AppQueueManager.set_stop_flag_no_user_check(task_id="t1") + + mock_redis.setex.assert_called_once() + + def test_is_stopped_reads_cache(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = b"1" + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + assert manager._is_stopped() is True + + def test_check_for_sqlalchemy_models_raises(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + bad = SimpleNamespace(_sa_instance_state=True) + with pytest.raises(TypeError): + manager._check_for_sqlalchemy_models(bad) diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py new file mode 100644 index 0000000000..aabeb54553 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from models.model import AppMode + + +class _DummyParameterRule: + def __init__(self, name: str, use_template: str | None = None) -> None: + 
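+        # lightweight stand-in for a model-runtime parameter rule; the recalc tests
+        # below only read `name` and `use_template`, so nothing else is modeled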
self.name = name + self.use_template = use_template + + +class _QueueRecorder: + def __init__(self) -> None: + self.events: list[object] = [] + + def publish(self, event, pub_from): + _ = pub_from + self.events.append(event) + + +class TestAppRunner: + def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80), + ) + + runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")]) + + assert model_config.parameters["max_tokens"] == 20 + + def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10), + ) + + assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1 + + def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True) + + monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None) + + runner.direct_output( + queue_manager=queue, + app_generate_entity=app_generate_entity, + prompt_messages=[], + text="hi", + stream=True, + ) + + assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_direct_publishes_end_event(self): + runner = AppRunner() + queue = _QueueRecorder() + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + runner._handle_invoke_result( + invoke_result=llm_result, + queue_manager=queue, + stream=False, + ) + + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_invalid_type_raises(self): + runner = AppRunner() + queue = _QueueRecorder() + + with pytest.raises(NotImplementedError): + runner._handle_invoke_result( + invoke_result=["unexpected"], + queue_manager=queue, + stream=True, + ) + + def test_organize_prompt_messages_simple_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=["STOP"]) + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hello", + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.SimplePromptTransform.get_prompt", + lambda self, **kwargs: (["simple-message"], ["simple-stop"]), + ) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + 
prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["simple-message"] + assert stop == ["simple-stop"] + + def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="completion", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="answer", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"), + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-completion-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-completion-message"] + assert stop == [""] + memory_config = captured["memory_config"] + assert memory_config.role_prefix.user == "U" + assert memory_config.role_prefix.assistant == "A" + + def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-chat-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-chat-message"] + assert stop == [""] + assert len(captured["prompt_template"]) == 2 + + def test_organize_prompt_messages_advanced_missing_templates_raise(self): + runner = AppRunner() + + with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="completion", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="chat", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + warning_logger = MagicMock() + 
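+        # swap in a mock logger so the unknown-content warning emitted below can be asserted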
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger) + + image_content = ImagePromptMessageContent( + url="https://example.com/image.png", format="png", mime_type="image/png" + ) + + def _stream(): + yield LLMResultChunk( + model="stream-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage.model_construct( + content=[ + "a", + TextPromptMessageContent(data="b"), + SimpleNamespace(data="c"), + image_content, + ] + ), + ), + ) + + runner._handle_invoke_result( + invoke_result=_stream(), + queue_manager=queue, + stream=True, + agent=False, + ) + + assert isinstance(queue.events[0], QueueLLMChunkEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.message.content == "abc" + warning_logger.assert_called_once() + + def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + exception_logger = MagicMock() + monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger) + + monkeypatch.setattr( + runner, + "_handle_multimodal_image_content", + MagicMock(side_effect=RuntimeError("failed to save image")), + ) + usage = LLMUsage.empty_usage() + + def _stream(): + yield LLMResultChunk( + model="agent-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + ImagePromptMessageContent( + url="https://example.com/image.png", + format="png", + mime_type="image/png", + ), + TextPromptMessageContent(data="done"), + ] + ), + usage=usage, + ), + ) + + runner._handle_invoke_result_stream( + invoke_result=_stream(), + queue_manager=queue, + agent=True, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + ) + + assert isinstance(queue.events[0], QueueAgentMessageEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.usage == usage + exception_logger.assert_called_once() + + def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch): + runner = AppRunner() + + class _ToggleBool: + def __init__(self, values: list[bool]): + self._values = values + self._index = 0 + + def __bool__(self): + value = self._values[min(self._index, len(self._values) - 1)] + self._index += 1 + return value + + content = SimpleNamespace( + url=_ToggleBool([False, False]), + base64_data=_ToggleBool([True, False]), + mime_type="image/png", + ) + + db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session)) + + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock()) + + runner._handle_multimodal_image_content( + content=content, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + queue_manager=queue_manager, + ) + + db_session.add.assert_not_called() + queue_manager.publish.assert_not_called() + + def test_check_hosting_moderation_direct_output_called(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(stream=False) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.HostingModerationFeature.check", + lambda self, application_generate_entity, 
prompt_messages: True, + ) + direct_output = MagicMock() + monkeypatch.setattr(runner, "direct_output", direct_output) + + result = runner.check_hosting_moderation( + application_generate_entity=app_generate_entity, + queue_manager=queue, + prompt_messages=[], + ) + + assert result is True + assert direct_output.called + + def test_fill_in_inputs_from_external_data_tools(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.ExternalDataFetch.fetch", + lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"}, + ) + + result = runner.fill_in_inputs_from_external_data_tools( + tenant_id="tenant", + app_id="app", + external_data_tools=[], + inputs={}, + query="q", + ) + + assert result == {"foo": "bar"} + + def test_moderation_for_inputs_returns_result(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.InputModeration.check", + lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""), + ) + app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None) + + result = runner.moderation_for_inputs( + app_id="app", + tenant_id="tenant", + app_generate_entity=app_generate_entity, + inputs={}, + query="q", + message_id="msg", + ) + + assert result == (True, {}, "") + + def test_query_app_annotations_to_reply(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.AnnotationReplyFeature.query", + lambda self, app_record, message, query, user_id, invoke_from: "reply", + ) + + response = runner.query_app_annotations_to_reply( + app_record=SimpleNamespace(), + message=SimpleNamespace(), + query="hello", + user_id="user", + invoke_from=InvokeFrom.WEB_APP, + ) + + assert response == "reply" diff --git a/api/tests/unit_tests/core/app/apps/test_exc.py b/api/tests/unit_tests/core/app/apps/test_exc.py new file mode 100644 index 0000000000..e41c78e89e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_exc.py @@ -0,0 +1,7 @@ +from core.app.apps.exc import GenerateTaskStoppedError + + +class TestAppsExceptions: + def test_generate_task_stopped_error(self): + err = GenerateTaskStoppedError("stopped") + assert str(err) == "stopped" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py index 87b8dc51e7..1250ac5ecf 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py @@ -13,9 +13,11 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.app.apps import message_based_app_generator +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from models.model import AppMode, Conversation, Message +from services.errors.app_model_config import AppModelConfigBrokenError class DummyModelConf: @@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity(): assert entity.conversation_id == "generated-conversation-id" assert entity.is_new_conversation is True assert conversation.id == "generated-conversation-id" + + +class TestMessageBasedAppGeneratorExtras: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = MessageBasedAppGenerator() + 
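+        # stub pipeline whose process() raises the "closed file" error; the generator
+        # is expected to translate it into GenerateTaskStoppedError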
+ class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + stream=False, + ) + + def test_get_app_model_config_requires_valid_config(self, monkeypatch): + generator = MessageBasedAppGenerator() + app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model, conversation=None) + + conversation = SimpleNamespace(app_model_config_id="missing-id") + monkeypatch.setattr( + message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None)) + ) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation) + + def test_get_conversation_introduction_handles_missing_inputs(self): + app_config = _make_app_config(AppMode.CHAT) + app_config.additional_features.opening_statement = "Hello {{name}}" + entity = _make_chat_generate_entity(app_config) + entity.inputs = {} + + generator = MessageBasedAppGenerator() + + assert generator._get_conversation_introduction(entity) == "Hello {name}" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py new file mode 100644 index 0000000000..847ad0ce9b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py @@ -0,0 +1,65 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent + + +class TestMessageBasedAppQueueManager: + def test_publish_stops_on_terminal_events(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager.stop_listen = Mock() + manager._is_stopped = Mock(return_value=False) + + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock()) + manager.stop_listen.assert_called_once() + + def test_publish_raises_when_stopped(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER) + + def 
test_publish_enqueues_message_end(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=False) + manager.stop_listen = Mock() + + manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + assert manager._q.qsize() == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py new file mode 100644 index 0000000000..25377e633e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode + + +class TestMessageGenerator: + def test_get_response_topic(self): + channel = Mock() + channel.topic.return_value = "topic" + + with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel): + topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1") + + assert topic == "topic" + expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1") + channel.topic.assert_called_once_with(expected_key) + + def test_retrieve_events_passes_arguments(self): + with ( + patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"), + patch( + "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) + ) as mock_stream, + ): + events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + + assert events == [{"event": "ping"}] + mock_stream.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index e019a4b977..4f67d9cb56 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -8,6 +8,8 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus @@ -22,7 +24,7 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig +from dify_graph.nodes.base.entities import OutputVariableEntity from dify_graph.nodes.base.node import Node from dify_graph.nodes.end.entities import EndNodeData from dify_graph.nodes.start.entities import StartNodeData @@ -42,6 +44,7 @@ if "core.ops.ops_trace_manager" not in sys.modules: class _StubToolNodeData(BaseNodeData): + type: NodeType = NodeType.TOOL pause_on: bool = False @@ -88,16 +91,17 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): 
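+    # keep a reference to the unpatched factory so non-tool nodes are still built normally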
original_create_node = DifyNodeFactory.create_node - def _patched_create_node(self, node_config: dict[str, object]) -> Node: - node_data = node_config.get("data", {}) - if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: + def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + if node_data.type == NodeType.TOOL: return _StubToolNode( - id=str(node_config["id"]), - config=node_config, + id=str(typed_node_config["id"]), + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) - return original_create_node(self, node_config) + return original_create_node(self, typed_node_config) mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index 7b5447c01e..a7714c56ce 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -6,6 +6,7 @@ import queue import pytest from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value with pytest.raises(StopIteration): next(generator) + + +def test_normalize_terminal_events_defaults(): + assert _normalize_terminal_events(None) == { + StreamEvent.WORKFLOW_FINISHED.value, + StreamEvent.WORKFLOW_PAUSED.value, + } + + +def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): + topic = FakeTopic() + times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] + + def fake_time(): + return times.pop(0) + + monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time) + + generator = stream_topic_events( + topic=topic, + idle_timeout=10.0, + ping_interval=1.0, + ) + + assert next(generator) == StreamEvent.PING.value + # next receive yields None -> ping interval triggers + assert next(generator) == StreamEvent.PING.value diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py new file mode 100644 index 0000000000..108b740344 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import NodeType +from dify_graph.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunStartedEvent, + 
NodeRunStreamChunkEvent, +) +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable + + +class TestWorkflowBasedAppRunner: + def test_resolve_user_from(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER + + def test_init_graph_validates_graph_structure(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + with pytest.raises(ValueError, match="nodes or edges not found"): + runner._init_graph( + graph_config={}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="nodes in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": {}, "edges": []}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="edges in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": [], "edges": {}}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + def test_prepare_single_node_execution_requires_run(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + workflow = SimpleNamespace(environment_variables=[], graph_dict={}) + + with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): + runner._prepare_single_node_execution(workflow, None, None) + + def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + graph_config = { + "nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}], + "edges": [], + } + workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + lambda **kwargs: SimpleNamespace(), + ) + + class _NodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + from core.app.apps import workflow_app_runner + + monkeypatch.setattr( + workflow_app_runner, + "resolve_workflow_node_class", + lambda **_kwargs: _NodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="node-1", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + ) + + assert graph is not None + assert variable_pool is graph_runtime_state.variable_pool + + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): + published: list[object] = [] + + class _QueueManager: + 
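+            # minimal queue manager that records published events for the assertions below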
def publish(self, event, publish_from): + published.append((event, publish_from)) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + graph_runtime_state.register_paused_node("node-1") + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + emails: list[dict] = [] + + class _Dispatch: + def apply_async(self, *, kwargs, queue): + emails.append({"kwargs": kwargs, "queue": queue}) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.dispatch_human_input_email_task", + _Dispatch(), + ) + + reason = HumanInputRequired( + form_id="form", + form_content="content", + node_id="node-1", + node_title="Node", + ) + + runner._handle_event(workflow_entry, GraphRunStartedEvent()) + runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True})) + runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={})) + + assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published) + assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published) + paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent)) + assert paused_event.paused_nodes == ["node-1"] + assert emails + + def test_handle_node_events_publishes_queue_events(self): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + runner._handle_event( + workflow_entry, + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + node_title="Start", + start_at=datetime.utcnow(), + ), + ) + runner._handle_event( + workflow_entry, + NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + selector=["node", "text"], + chunk="hi", + is_final=False, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunAgentLogEvent( + id="exec", + node_id="node", + node_type=NodeType.START, + message_id="msg", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunIterationSucceededEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="Iter", + start_at=datetime.utcnow(), + inputs={}, + outputs={"ok": True}, + metadata={}, + steps=1, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunLoopFailedEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="Loop", + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + metadata={}, + steps=1, + error="boom", + ), + ) + + assert any(isinstance(event, QueueTextChunkEvent) for event in published) + assert any(isinstance(event, QueueAgentLogEvent) for event in published) + assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) + assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py 
b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 2e0715e974..178e26118e 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -7,7 +7,9 @@ import pytest from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from models.workflow import Workflow @@ -105,3 +107,57 @@ def test_run_uses_single_node_execution_branch( assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER assert entry_kwargs["variable_pool"] is variable_pool assert entry_kwargs["graph_runtime_state"] is graph_runtime_state + + +def test_single_node_run_validates_target_node_config(monkeypatch) -> None: + runner = WorkflowBasedAppRunner( + queue_manager=MagicMock(spec=AppQueueManager), + variable_loader=MagicMock(), + app_id="app", + ) + + workflow = MagicMock(spec=Workflow) + workflow.id = "workflow" + workflow.tenant_id = "tenant" + workflow.graph_dict = { + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + } + ], + "edges": [], + } + + _, _, graph_runtime_state = _make_graph_state() + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + with ( + patch("core.app.apps.workflow_app_runner.DifyNodeFactory"), + patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()), + patch("core.app.apps.workflow_app_runner.load_into_variable_pool"), + patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/workflow/__init__.py b/api/tests/unit_tests/core/app/apps/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py new file mode 100644 index 0000000000..f8dd6bf609 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from models.model import AppMode + + +class TestWorkflowAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + 
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.WORKFLOW + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + # Support both positional and keyword arguments for config + if "config" in kwargs: + config = kwargs["config"] + elif len(args) > 0: + config = args[0] + else: + config = {} + config[key] = value + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 2), + ), + patch( + "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 3), + ), + ): + filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["text_to_speech"] == 2 + assert filtered["sensitive_word_avoidance"] == 3 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py new file mode 100644 index 0000000000..09ad078a70 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestWorkflowAppGeneratorValidation: + def test_should_prepare_user_inputs(self): + generator = WorkflowAppGenerator() + + assert generator._should_prepare_user_inputs({}) is True + assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False + + def test_single_iteration_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with 
pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestWorkflowAppGeneratorHandleResponse: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_execution_id="run-id", + call_depth=0, + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + +class TestWorkflowAppGeneratorGenerate: + def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.file_factory.build_from_mappings", + lambda **kwargs: [], + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + prepare_inputs = pytest.fail + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs()) + + monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True}) + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + 
workflow=SimpleNamespace(features_dict={}), + user=SimpleNamespace(id="user", session_id="session"), + args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + call_depth=0, + ) + + assert result == {"ok": True} diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py new file mode 100644 index 0000000000..6133be9867 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent + + +class TestWorkflowAppQueueManager: + def test_publish_stop_events_trigger_stop(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + manager._is_stopped = lambda: True + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER) + + def test_publish_non_stop_event_does_not_raise(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + + manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_errors.py b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py new file mode 100644 index 0000000000..7461e06833 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py @@ -0,0 +1,9 @@ +from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError + + +class TestWorkflowErrors: + def test_workflow_paused_in_blocking_mode_error_attributes(self): + err = WorkflowPausedInBlockingModeError() + assert err.error_code == "workflow_paused_in_blocking_mode" + assert err.code == 400 + assert "blocking response mode" in err.description diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py new file mode 100644 index 0000000000..62e94a7580 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -0,0 +1,133 @@ +from collections.abc import Generator + +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +class TestWorkflowGenerateResponseConverter: + def test_blocking_full_response(self): + blocking = WorkflowAppBlockingResponse( + task_id="t1", + workflow_run_id="r1", + data=WorkflowAppBlockingResponse.Data( + id="exec-1", + workflow_id="wf-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.2, + total_tokens=10, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + response = 
WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + assert response["workflow_run_id"] == "r1" + + def test_stream_simple_response_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[WorkflowAppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1")) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish) + yield WorkflowAppStreamResponse( + workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")) + ) + + converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" + + def test_convert_stream_simple_response_handles_ping_and_nodes(self): + def _gen(): + yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task")) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeStartStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + created_at=1, + ), + ), + ) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + created_at=1, + finished_at=2, + elapsed_time=1.0, + error=None, + ), + ), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen())) + + assert chunks[0] == "ping" + assert chunks[1]["event"] == "node_started" + assert chunks[2]["event"] == "node_finished" + + def test_convert_stream_full_response_handles_error(self): + def _gen(): + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen())) + + assert chunks[0]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..b37f7a8120 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -0,0 +1,868 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from 
core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + PingStreamResponse, + WorkflowFinishStreamResponse, + WorkflowPauseStreamResponse, + WorkflowStartStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import NodeType, WorkflowExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import CreatorUserRole +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = SimpleNamespace(id="user", session_id="session") + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=user, + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestWorkflowGenerateTaskPipeline: + def test_to_blocking_response_handles_pause(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.status == WorkflowExecutionStatus.PAUSED + + def test_to_blocking_response_handles_finish(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowFinishStreamResponse.Data( + id="run", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.0, + total_tokens=5, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.outputs == {"ok": True} + + def 
test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_execution_id == "run-id" + assert responses == ["started"] + + def test_handle_node_succeeded_event_saves_output(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + pipeline._workflow_execution_id = "run-id" + + event = QueueNodeSucceededEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + + def test_handle_workflow_failed_event_yields_error(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list( + pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1)) + ) + + assert responses[0] == "finish" + + def test_handle_text_chunk_event_publishes_tts(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert 
responses[0].data.text == "hi" + assert published == [queue_message] + + def test_dispatch_event_handles_node_failed(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + + event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["done"] + + def test_handle_stop_event_yields_finish(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + + responses = list( + pipeline._handle_workflow_failed_and_stop_events( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + ) + ) + + assert responses == ["finish"] + + def test_save_workflow_app_log_created_from(self): + pipeline = _make_pipeline() + pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API + pipeline._user_id = "user" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + + assert added + + def test_iteration_loop_and_human_input_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter" + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done" + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + 
node_id="node", + node_type=NodeType.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=NodeType.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=NodeType.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + agent_event = QueueAgentLogEvent( + id="log", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + node_id="node", + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"] + + def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return AudioTrunk(status="stream", audio="data") + if self.calls == 2: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + return None + + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_init_with_end_user_sets_role_and_system_user(self): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None) + end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id") + end_user.id = "end-user-id" + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=end_user, + stream=False, + draft_var_saver_factory=lambda **kwargs: 
None, + ) + + assert pipeline._created_by_role == CreatorUserRole.END_USER + assert pipeline._workflow_system_variables.user_id == "session-id" + + def test_process_returns_stream_and_blocking_variants(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.stream = True + pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + stream_response = list(pipeline.process()) + assert len(stream_response) == 1 + assert stream_response[0].workflow_run_id is None + + pipeline._base_task_pipeline.stream = False + pipeline._wrapper_process_stream_response = lambda **kwargs: iter( + [ + WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowFinishStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + created_at=1, + finished_at=2, + ), + ) + ] + ) + + blocking_response = pipeline.process() + assert blocking_response.workflow_run_id == "run-id" + + def test_to_blocking_response_handles_error_and_unexpected_end(self): + pipeline = _make_pipeline() + + def _error_gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("boom")) + + with pytest.raises(ValueError, match="boom"): + pipeline._to_blocking_response(_error_gen()) + + def _unexpected_gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(ValueError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_unexpected_gen()) + + def test_to_stream_response_tracks_workflow_run_id(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowStartStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowStartStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + inputs={}, + created_at=1, + ), + ) + yield PingStreamResponse(task_id="task") + + stream_responses = list(pipeline._to_stream_response(_gen())) + assert stream_responses[0].workflow_run_id == "run-id" + assert stream_responses[1].workflow_run_id == "run-id" + + def test_listen_audio_msg_returns_none_without_publisher(self): + pipeline = _make_pipeline() + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts(self): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = {} + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + responses = list(pipeline._wrapper_process_stream_response()) + assert responses == [PingStreamResponse(task_id="task")] + + def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + sleep_spy = [] + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + time_values = iter([0.0, 0.0, 0.2]) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values)) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True) + ) + monkeypatch.setattr( + 
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert sleep_spy + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.called = False + + def check_and_get_audio(self): + if not self.called: + self.called = True + raise RuntimeError("tts failure") + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + logger_exception = [] + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.logger.exception", + lambda *args, **kwargs: logger_exception.append((args, kwargs)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert logger_exception + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_database_session_rolls_back_on_error(self, monkeypatch): + pipeline = _make_pipeline() + calls = {"commit": 0, "rollback": 0} + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + calls["commit"] += 1 + + def rollback(self): + calls["rollback"] += 1 + + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + with pytest.raises(RuntimeError, match="db error"): + with pipeline._database_session(): + raise RuntimeError("db error") + + assert calls["commit"] == 0 + assert calls["rollback"] == 1 + + def test_node_retry_and_started_handlers_cover_none_and_value(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + + retry_event = QueueNodeRetryEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=NodeType.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + error="error", + retry_index=1, + ) + started_event = QueueNodeStartedEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=NodeType.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + ) + + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_retry_event(retry_event)) == [] + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry" + assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"] + + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_started_event(started_event)) == [] + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: 
"started" + assert list(pipeline._handle_node_started_event(started_event)) == ["started"] + + def test_handle_node_exception_event_saves_output(self): + pipeline = _make_pipeline() + saved_ids: list[str] = [] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id) + + event = QueueNodeExceptionEvent( + node_execution_id="exec-id", + node_id="node", + node_type=NodeType.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="boom", + ) + + responses = list(pipeline._handle_node_failed_events(event)) + assert responses == ["failed"] + assert saved_ids == ["exec-id"] + + def test_success_partial_and_pause_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"] + assert list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={}) + ) + ) == ["finish"] + + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [ + "pause-a", + "pause-b", + ] + pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"]) + assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"] + + def test_text_chunk_handler_returns_empty_when_text_missing(self): + pipeline = _make_pipeline() + event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None) + assert list(pipeline._handle_text_chunk_event(event)) == [] + + def test_dispatch_event_direct_failed_and_unhandled_paths(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) + assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"] + + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"]) + assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [ + "workflow-failed" + ] + + assert list(pipeline._dispatch_event(SimpleNamespace())) == [] + + def test_process_stream_response_main_match_paths_and_cleanup(self): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=QueueWorkflowStartedEvent()), + SimpleNamespace(event=QueueTextChunkEvent(text="hello")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueErrorEvent(error="e")), + ] + ) + pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"]) + pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"]) + pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"]) + 
pipeline._handle_error_event = lambda event, **kwargs: iter(["error"]) + publisher_calls: list[object] = [] + + class _Publisher: + def publish(self, message): + publisher_calls.append(message) + + responses = list(pipeline._process_stream_response(tts_publisher=_Publisher())) + assert responses == ["started", "text", "dispatched", "error"] + assert publisher_calls == [None] + + def test_process_stream_response_break_paths(self): + pipeline = _make_pipeline() + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"]) + assert list(pipeline._process_stream_response()) == ["failed"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))] + ) + pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"]) + assert list(pipeline._process_stream_response()) == ["paused"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"]) + assert list(pipeline._process_stream_response()) == ["stopped"] + + def test_save_workflow_app_log_covers_invoke_from_variants(self): + pipeline = _make_pipeline() + pipeline._user_id = "user-id" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "installed-app" + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "web-app" + + count_before = len(added) + pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert len(added) == count_before + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) + assert len(added) == count_before + + def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + pipeline = _make_pipeline() + saver_calls: list[tuple[object, object]] = [] + captured_factory_args: dict[str, object] = {} + + class _Saver: + def save(self, process_data, outputs): + saver_calls.append((process_data, outputs)) + + def _factory(**kwargs): + captured_factory_args.update(kwargs) + return _Saver() + + class _Begin: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb): + return False + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return _Begin() + + pipeline._draft_var_saver_factory = _factory + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + event = QueueNodeSucceededEvent( + node_execution_id="exec-id", + node_id="node-id", + node_type=NodeType.START, + 
in_loop_id="loop-id", + start_at=datetime.utcnow(), + process_data={"k": "v"}, + outputs={"out": 1}, + ) + pipeline._save_output_for_event(event=event, node_execution_id="exec-id") + + assert captured_factory_args["node_execution_id"] == "exec-id" + assert captured_factory_args["enclosing_node_id"] == "loop-id" + assert saver_calls == [({"k": "v"}, {"out": 1})] diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py new file mode 100644 index 0000000000..3759b6aa37 --- /dev/null +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -0,0 +1,390 @@ +import base64 +import queue +from unittest.mock import MagicMock + +import pytest + +from core.base.tts.app_generator_tts_publisher import ( + AppGeneratorTTSPublisher, + AudioTrunk, + _invoice_tts, + _process_future, +) + +# ========================= +# Fixtures +# ========================= + + +@pytest.fixture +def mock_model_instance(mocker): + model = mocker.MagicMock() + model.invoke_tts.return_value = [b"audio1", b"audio2"] + model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}] + return model + + +@pytest.fixture +def mock_model_manager(mocker, mock_model_instance): + manager = mocker.MagicMock() + manager.get_default_model_instance.return_value = mock_model_instance + mocker.patch( + "core.base.tts.app_generator_tts_publisher.ModelManager", + return_value=manager, + ) + return manager + + +@pytest.fixture(autouse=True) +def patch_threads(mocker): + """Prevent real threads from starting during tests""" + mocker.patch("threading.Thread.start", return_value=None) + + +# ========================= +# AudioTrunk Tests +# ========================= + + +class TestAudioTrunk: + def test_audio_trunk_initialization(self): + trunk = AudioTrunk("responding", b"data") + assert trunk.status == "responding" + assert trunk.audio == b"data" + + +# ========================= +# _invoice_tts Tests +# ========================= + + +class TestInvoiceTTS: + @pytest.mark.parametrize( + "text", + [None, "", " "], + ) + def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): + result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + assert result is None + mock_model_instance.invoke_tts.assert_not_called() + + def test_invoice_tts_valid_text(self, mock_model_instance): + result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="hello", + user="responding_tts", + tenant_id="tenant", + voice="voice1", + ) + assert result == [b"audio1", b"audio2"] + + +# ========================= +# _process_future Tests +# ========================= + + +class TestProcessFuture: + def test_process_future_normal_flow(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = [b"abc"] + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + first = audio_queue.get() + assert first.status == "responding" + assert first.audio == base64.b64encode(b"abc") + + finish = audio_queue.get() + assert finish.status == "finish" + + def test_process_future_empty_result(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = None + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + finish = 
audio_queue.get() + assert finish.status == "finish" + + def test_process_future_exception(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.side_effect = Exception("error") + + future_queue.put(future) + + _process_future(future_queue, audio_queue) + + finish = audio_queue.get() + assert finish.status == "finish" + + +# ========================= +# AppGeneratorTTSPublisher Tests +# ========================= + + +class TestAppGeneratorTTSPublisher: + def test_initialization_valid_voice(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + assert publisher.voice == "voice1" + assert publisher.max_sentence == 2 + assert publisher.msg_text == "" + + def test_initialization_invalid_voice_fallback(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice") + assert publisher.voice == "voice1" + + def test_publish_puts_message_in_queue(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + message = MagicMock() + publisher.publish(message) + assert publisher._msg_queue.get() == message + + def test_check_and_get_audio_no_audio(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + result = publisher.check_and_get_audio() + assert result is None + + def test_check_and_get_audio_non_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + trunk = AudioTrunk("responding", b"abc") + publisher._audio_queue.put(trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "responding" + assert publisher._last_audio_event == trunk + + def test_check_and_get_audio_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + finish_trunk = AudioTrunk("finish", b"") + publisher._audio_queue.put(finish_trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + def test_check_and_get_audio_cached_finish(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher._last_audio_event = AudioTrunk("finish", b"") + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + @pytest.mark.parametrize( + ("text", "expected_sentences", "expected_remaining"), + [ + ("Hello world.", ["Hello world."], ""), + ("Hello world! How are you?", ["Hello world!", " How are you?"], ""), + ("No punctuation", [], "No punctuation"), + ("", [], ""), + ], + ) + def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + sentences, remaining = publisher._extract_sentence(text) + assert sentences == expected_sentences + assert remaining == expected_remaining + + def test_runtime_handles_none_message_with_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = "Hello."
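+ # A None message is the queue's end-of-stream sentinel: _runtime is expected to flush the + # buffered text as one final TTS job before exiting, which is why exactly one + # executor.submit call is asserted below (the no-buffer variant expects none).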
+ + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_called_once() + + def test_runtime_handles_none_message_without_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = " " + + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + # Force sentence extraction to hit threshold condition + mocker.patch.object( + publisher, + "_extract_sentence", + return_value=(["Hello world.", " Second sentence."], ""), + ) + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Second sentence." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_text_chunk_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = {"output": "Hello world."} + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = None + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data="Hello "), + ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"), + ] + ), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, 
"_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "Hello " + + def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, "_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Another sentence." + + mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None)) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_exception_path(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher._msg_queue = MagicMock() + publisher._msg_queue.get.side_effect = Exception("error") + + publisher._runtime() diff --git a/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py new file mode 100644 index 0000000000..4c1aa33540 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock + +import pytest + +import core.callback_handler.agent_tool_callback_handler as module + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def enable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", True) + + +@pytest.fixture +def disable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", False) + + +@pytest.fixture +def mock_print(mocker): + return mocker.patch("builtins.print") + + +@pytest.fixture +def handler(): + return module.DifyAgentCallbackHandler(color="blue") + + +# ----------------------------- +# get_colored_text Tests +# ----------------------------- + + +class TestGetColoredText: + @pytest.mark.parametrize( + ("color", "expected_code"), + [ + ("blue", "36;1"), + ("yellow", "33;1"), + ("pink", "38;5;200"), + ("green", "32;1"), + ("red", "31;1"), + ], + ) + def test_get_colored_text_valid_colors(self, color, expected_code): + text = "hello" + result = module.get_colored_text(text, color) + assert expected_code in result + assert text in result + assert result.endswith("\u001b[0m") + + def test_get_colored_text_invalid_color_raises(self): + with pytest.raises(KeyError): + module.get_colored_text("hello", "invalid") + + def test_get_colored_text_empty_string(self): + 
result = module.get_colored_text("", "green") + assert "\u001b[" in result + + +# ----------------------------- +# print_text Tests +# ----------------------------- + + +class TestPrintText: + def test_print_text_without_color(self, mock_print): + module.print_text("hello") + mock_print.assert_called_once_with("hello", end="", file=None) + + def test_print_text_with_color(self, mocker, mock_print): + mock_get_color = mocker.patch( + "core.callback_handler.agent_tool_callback_handler.get_colored_text", + return_value="colored_text", + ) + + module.print_text("hello", color="green") + + mock_get_color.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("colored_text", end="", file=None) + + def test_print_text_with_file_flush(self, mocker): + mock_file = MagicMock() + mock_print = mocker.patch("builtins.print") + + module.print_text("hello", file=mock_file) + + mock_print.assert_called_once_with("hello", end="", file=mock_file) + mock_file.flush.assert_called_once() + + def test_print_text_with_end_parameter(self, mock_print): + module.print_text("hello", end="\n") + mock_print.assert_called_once_with("hello", end="\n", file=None) + + +# ----------------------------- +# DifyAgentCallbackHandler Tests +# ----------------------------- + + +class TestDifyAgentCallbackHandler: + def test_init_default_color(self): + handler = module.DifyAgentCallbackHandler() + assert handler.color == "green" + assert handler.current_loop == 1 + + def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_called() + + def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_not_called() + + def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + mock_trace_manager = MagicMock() + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={"a": 1}, + tool_outputs="output", + message_id="msg1", + timer=123, + trace_manager=mock_trace_manager, + ) + + assert mock_print_text.call_count >= 1 + mock_trace_manager.add_trace_task.assert_called_once() + + def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={}, + tool_outputs="output", + ) + + assert mock_print_text.call_count >= 1 + + def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_called_once() + + def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_not_called() + + @pytest.mark.parametrize("thought", ["thinking", ""]) + def test_on_agent_start(self, handler, enable_debug, mocker, thought): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + 
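+ # Parametrized over a non-empty and an empty thought: with DEBUG enabled, + # on_agent_start is expected to reach print_text in both cases.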
+ handler.on_agent_start(thought) + + mock_print_text.assert_called() + + def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + current_loop = handler.current_loop + handler.on_agent_finish() + + assert handler.current_loop == current_loop + 1 + mock_print_text.assert_called() + + def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_datasource_start("ds1", {"x": 1}) + + mock_print_text.assert_called_once() + + def test_ignore_agent_property(self, disable_debug, handler): + assert handler.ignore_agent is True + + def test_ignore_chat_model_property(self, disable_debug, handler): + assert handler.ignore_chat_model is True + + def test_ignore_properties_when_debug_enabled(self, enable_debug, handler): + assert handler.ignore_agent is False + assert handler.ignore_chat_model is False diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py new file mode 100644 index 0000000000..b37c4c57a1 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -0,0 +1,162 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import ( + DatasetIndexToolCallbackHandler, +) + + +@pytest.fixture +def mock_queue_manager(mocker): + return mocker.Mock() + + +@pytest.fixture +def handler(mock_queue_manager, mocker): + mocker.patch( + "core.callback_handler.index_tool_callback_handler.db", + ) + return DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + +class TestOnQuery: + @pytest.mark.parametrize( + ("invoke_from", "expected_role"), + [ + (InvokeFrom.EXPLORE, "account"), + (InvokeFrom.DEBUGGER, "account"), + (InvokeFrom.WEB_APP, "end_user"), + ], + ) + def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role): + # Arrange + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + handler._invoke_from = invoke_from + + # Act + handler.on_query("test query", "dataset-1") + + # Assert + mock_db.session.add.assert_called_once() + dataset_query = mock_db.session.add.call_args.args[0] + assert dataset_query.created_by_role == expected_role + mock_db.session.commit.assert_called_once() + + def test_on_query_none_values(self, mocker, mock_queue_manager): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id=None, + message_id=None, + user_id=None, + invoke_from=None, + ) + + handler.on_query(None, None) + + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +class TestOnToolEnd: + def test_on_tool_end_no_metadata(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + document = mocker.Mock() + document.metadata = None + + handler.on_tool_end([document]) + + 
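+ # A document without metadata carries no segment identifiers to update, so the + # handler is expected to skip it without writing anything to the session.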
mock_db.session.commit.assert_not_called() + + def test_on_tool_end_dataset_document_not_found(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_db.session.scalar.return_value = None + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + handler.on_tool_end([document]) + + mock_db.session.scalar.assert_called_once() + + def test_on_tool_end_parent_child_index_with_child(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + from core.callback_handler.index_tool_callback_handler import IndexStructureType + + mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX + mock_dataset_doc.dataset_id = "dataset-1" + mock_dataset_doc.id = "doc-1" + + mock_child_chunk = mocker.Mock() + mock_child_chunk.segment_id = "segment-1" + + mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_non_parent_child_index(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + mock_dataset_doc.doc_form = "OTHER" + + mock_db.session.scalar.return_value = mock_dataset_doc + + document = mocker.Mock() + document.metadata = { + "document_id": "doc-1", + "doc_id": "node-1", + "dataset_id": "dataset-1", + } + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_empty_documents(self, handler): + handler.on_tool_end([]) + + +class TestReturnRetrieverResourceInfo: + def test_publish_called(self, handler, mock_queue_manager, mocker): + mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent") + + resources = [mocker.Mock()] + + handler.return_retriever_resource_info(resources) + + mock_queue_manager.publish.assert_called_once() diff --git a/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py new file mode 100644 index 0000000000..131fb006ed --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py @@ -0,0 +1,184 @@ +from unittest.mock import MagicMock, call + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import ( + DifyWorkflowCallbackHandler, +) + + +class DummyToolInvokeMessage: + """Lightweight dummy to simulate ToolInvokeMessage behavior.""" + + def __init__(self, json_value: str): + self._json_value = json_value + + def model_dump_json(self): + return self._json_value + + +@pytest.fixture +def handler(): + """Fixture to create handler instance with deterministic color.""" + instance = DifyWorkflowCallbackHandler() + instance.color = "blue" + return instance + + +@pytest.fixture +def mock_print_text(mocker): + """Mock print_text to avoid real stdout printing.""" + return
mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text") + + +class TestDifyWorkflowCallbackHandler: + def test_on_tool_execution_single_output_success(self, handler, mock_print_text): + # Arrange + tool_name = "test_tool" + tool_inputs = {"a": 1} + message = DummyToolInvokeMessage('{"key": "value"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs=tool_inputs, + tool_outputs=[message], + ) + ) + + # Assert + assert results == [message] + assert mock_print_text.call_count == 4 + mock_print_text.assert_has_calls( + [ + call("\n[on_tool_execution]\n", color="blue"), + call("Tool: test_tool\n", color="blue"), + call( + "Outputs: " + message.model_dump_json()[:1000] + "\n", + color="blue", + ), + call("\n"), + ] + ) + + def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text): + # Arrange + tool_name = "multi_tool" + outputs = [ + DummyToolInvokeMessage('{"id": 1}'), + DummyToolInvokeMessage('{"id": 2}'), + ] + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=outputs, + ) + ) + + # Assert + assert results == outputs + assert mock_print_text.call_count == 4 * len(outputs) + + def test_on_tool_execution_empty_iterable(self, handler, mock_print_text): + # Arrange + tool_name = "empty_tool" + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[], + ) + ) + + # Assert + assert results == [] + mock_print_text.assert_not_called() + + @pytest.mark.parametrize( + ("invalid_outputs", "expected_exception"), + [ + (None, TypeError), + (123, TypeError), + ("not_iterable", AttributeError), + ], + ) + def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception): + # Arrange + tool_name = "invalid_tool" + + # Act & Assert + with pytest.raises(expected_exception): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=invalid_outputs, + ) + ) + + def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text): + # Arrange + tool_name = "long_json_tool" + long_json = "x" * 1500 + message = DummyToolInvokeMessage(long_json) + + # Act + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + ) + ) + + # Assert + expected_truncated = long_json[:1000] + mock_print_text.assert_any_call( + "Outputs: " + expected_truncated + "\n", + color="blue", + ) + + def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text): + # Arrange + tool_name = "exception_tool" + bad_message = MagicMock() + bad_message.model_dump_json.side_effect = ValueError("JSON error") + + # Act & Assert + with pytest.raises(ValueError): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[bad_message], + ) + ) + + # Ensure first two prints happened before failure + assert mock_print_text.call_count >= 2 + + def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text): + # Arrange + tool_name = "optional_params_tool" + message = DummyToolInvokeMessage('{"data": "ok"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + message_id=None, + timer=None, + trace_manager=None, + ) + ) + + assert results == [message] + assert mock_print_text.call_count == 4 diff --git a/api/tests/unit_tests/core/datasource/test_file_upload.py 
b/api/tests/unit_tests/core/datasource/test_file_upload.py index ad86190e00..63b86e64fc 100644 --- a/api/tests/unit_tests/core/datasource/test_file_upload.py +++ b/api/tests/unit_tests/core/datasource/test_file_upload.py @@ -35,7 +35,7 @@ TEST COVERAGE OVERVIEW: - Tests hash consistency and determinism 6. Invalid Filename Handling (TestInvalidFilenameHandling) - - Validates rejection of filenames with invalid characters (/, \\, :, *, ?, ", <, >, |) + - Validates rejection of filenames with path separators (/, \\) - Tests filename length truncation (max 200 characters) - Prevents path traversal attacks - Handles edge cases like empty filenames @@ -535,30 +535,23 @@ class TestInvalidFilenameHandling: @pytest.mark.parametrize( "invalid_char", - ["/", "\\", ":", "*", "?", '"', "<", ">", "|"], + ["/", "\\"], ) def test_filename_contains_invalid_characters(self, invalid_char): """Test detection of invalid characters in filename. - Security-critical test that validates rejection of dangerous filename characters. + Security-critical test that validates rejection of path separators. These characters are blocked because they: - / and \\ : Directory separators, could enable path traversal - - : : Drive letter separator on Windows, reserved character - - * and ? : Wildcards, could cause issues in file operations - - " : Quote character, could break command-line operations - - < and > : Redirection operators, command injection risk - - | : Pipe operator, command injection risk Blocking these characters prevents: - Path traversal attacks (../../etc/passwd) - - Command injection - - File system corruption - - Cross-platform compatibility issues + - ZIP entry traversal issues + - Ambiguous path handling """ # Arrange - Create filename with invalid character filename = f"test{invalid_char}file.txt" - # Define complete list of invalid characters - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check if filename contains any invalid character has_invalid_char = any(c in filename for c in invalid_chars) @@ -570,7 +563,7 @@ class TestInvalidFilenameHandling: """Test that valid filenames pass validation.""" # Arrange filename = "valid_file-name_123.txt" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act has_invalid_char = any(c in filename for c in invalid_chars) @@ -578,6 +571,16 @@ class TestInvalidFilenameHandling: # Assert assert has_invalid_char is False + @pytest.mark.parametrize("safe_char", [":", "*", "?", '"', "<", ">", "|"]) + def test_filename_allows_safe_metadata_characters(self, safe_char): + """Test that non-separator punctuation remains allowed in filenames.""" + filename = f"candidate{safe_char}resume.txt" + invalid_chars = ["/", "\\"] + + has_invalid_char = any(c in filename for c in invalid_chars) + + assert has_invalid_char is False + def test_extremely_long_filename_truncation(self): """Test handling of extremely long filenames.""" # Arrange @@ -904,7 +907,7 @@ class TestFilenameValidation: """Test that filenames with spaces are handled correctly.""" # Arrange filename = "my document with spaces.pdf" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check for invalid characters has_invalid = any(c in filename for c in invalid_chars) @@ -921,7 +924,7 @@ class TestFilenameValidation: "مستند.txt", # Arabic "ファイル.jpg", # Japanese ] - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act & Assert - 
Unicode should be allowed for filename in unicode_filenames: diff --git a/api/tests/unit_tests/core/entities/test_entities_agent_entities.py b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py new file mode 100644 index 0000000000..2437602695 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py @@ -0,0 +1,9 @@ +from core.entities.agent_entities import PlanningStrategy + + +def test_planning_strategy_values_are_stable() -> None: + # Arrange / Act / Assert + assert PlanningStrategy.ROUTER.value == "router" + assert PlanningStrategy.REACT_ROUTER.value == "react_router" + assert PlanningStrategy.REACT.value == "react" + assert PlanningStrategy.FUNCTION_CALL.value == "function_call" diff --git a/api/tests/unit_tests/core/entities/test_entities_document_task.py b/api/tests/unit_tests/core/entities/test_entities_document_task.py new file mode 100644 index 0000000000..dd550930d7 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_document_task.py @@ -0,0 +1,18 @@ +from core.entities.document_task import DocumentTask + + +def test_document_task_keeps_indexing_identifiers() -> None: + # Arrange + document_ids = ("doc-1", "doc-2") + + # Act + task = DocumentTask( + tenant_id="tenant-1", + dataset_id="dataset-1", + document_ids=document_ids, + ) + + # Assert + assert task.tenant_id == "tenant-1" + assert task.dataset_id == "dataset-1" + assert task.document_ids == document_ids diff --git a/api/tests/unit_tests/core/entities/test_entities_embedding_type.py b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py new file mode 100644 index 0000000000..5a82fc4842 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py @@ -0,0 +1,7 @@ +from core.entities.embedding_type import EmbeddingInputType + + +def test_embedding_input_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert EmbeddingInputType.DOCUMENT.value == "document" + assert EmbeddingInputType.QUERY.value == "query" diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py new file mode 100644 index 0000000000..2e4f6d34fb --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -0,0 +1,45 @@ +from core.entities.execution_extra_content import ( + ExecutionExtraContentDomainModel, + HumanInputContent, + HumanInputFormDefinition, + HumanInputFormSubmissionData, +) +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType +from models.execution_extra_content import ExecutionContentType + + +def test_human_input_content_defaults_and_domain_alias() -> None: + # Arrange + form_definition = HumanInputFormDefinition( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="Please confirm", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")], + actions=[UserAction(id="confirm", title="Confirm")], + resolved_default_values={"answer": "yes"}, + expiration_time=1_700_000_000, + ) + submission_data = HumanInputFormSubmissionData( + node_id="node-1", + node_title="Human Input", + rendered_content="Please confirm", + action_id="confirm", + action_text="Confirm", + ) + + # Act + content = HumanInputContent( + workflow_run_id="workflow-run-1", + submitted=True, + form_definition=form_definition, + form_submission_data=submission_data, + ) + + # Assert + assert 
form_definition.model_config.get("frozen") is True + assert content.type == ExecutionContentType.HUMAN_INPUT + assert content.form_definition is form_definition + assert content.form_submission_data is submission_data + assert ExecutionExtraContentDomainModel is HumanInputContent diff --git a/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py new file mode 100644 index 0000000000..d25f20145f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py @@ -0,0 +1,45 @@ +from core.entities.knowledge_entities import ( + PipelineDataset, + PipelineDocument, + PipelineGenerateResponse, +) + + +def test_pipeline_dataset_normalizes_none_description() -> None: + # Arrange / Act + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description=None, + chunk_structure="parent-child", + ) + + # Assert + assert dataset.description == "" + + +def test_pipeline_generate_response_builds_nested_models() -> None: + # Arrange + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description="Knowledge base", + chunk_structure="parent-child", + ) + document = PipelineDocument( + id="doc-1", + position=1, + data_source_type="file", + data_source_info={"name": "spec.pdf"}, + name="spec.pdf", + indexing_status="completed", + enabled=True, + ) + + # Act + response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document]) + + # Assert + assert response.batch == "batch-1" + assert response.dataset.id == "dataset-1" + assert response.documents[0].id == "doc-1" diff --git a/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py new file mode 100644 index 0000000000..5449c63b45 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py @@ -0,0 +1,450 @@ +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities import mcp_provider as mcp_provider_module +from core.entities.mcp_provider import ( + DEFAULT_EXPIRES_IN, + DEFAULT_TOKEN_TYPE, + MCPProviderEntity, +) +from core.mcp.types import OAuthTokens + + +def _build_mcp_provider_entity() -> MCPProviderEntity: + now = datetime(2025, 1, 1, tzinfo=UTC) + return MCPProviderEntity( + id="provider-1", + provider_id="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[], + icon={"en_US": "icon.png"}, + created_at=now, + updated_at=now, + ) + + +def test_from_db_model_maps_fields() -> None: + # Arrange + now = datetime(2025, 1, 1, tzinfo=UTC) + db_provider = SimpleNamespace( + id="provider-1", + server_identifier="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={"Authorization": "enc"}, + timeout=15, + sse_read_timeout=120, + authed=True, + credentials={"access_token": "enc-token"}, + tool_dict=[{"name": "search"}], + icon=None, + created_at=now, + updated_at=now, + ) + + # Act + entity = MCPProviderEntity.from_db_model(db_provider) + + # Assert + assert entity.provider_id == "server-1" + assert entity.tools == [{"name": "search"}] + assert entity.icon == "" + + +def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + entity = 
_build_mcp_provider_entity() + monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com") + + # Act + redirect_url = entity.redirect_url + + # Assert + assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback" + + +def test_client_metadata_for_authorization_code_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_client_metadata_for_client_credentials_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"grant_types": ["client_credentials"]}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "client_credentials"] + assert metadata.redirect_uris == [] + assert metadata.response_types == [] + + +def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "grant_type": "client_credentials", + "client_information": {"grant_types": ["authorization_code"]}, + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_provider_icon_returns_icon_dict_as_is() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + # Act + icon = entity.provider_icon + + # Assert + assert icon == {"en_US": "icon.png"} + + +def test_provider_icon_uses_signed_url_for_plain_path() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"}) + + with patch( + "core.entities.mcp_provider.file_helpers.get_signed_file_url", + return_value="https://signed.example.com/icons/mcp.png", + ) as mock_get_signed_url: + # Act + icon = entity.provider_icon + + # Assert + mock_get_signed_url.assert_called_once_with("icons/mcp.png") + assert icon == "https://signed.example.com/icons/mcp.png" + + +def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + # Act + response = entity.to_api_response(include_sensitive=False) + + # Assert + assert response["author"] == "Anonymous" + assert response["masked_headers"] == {} + assert response["is_dynamic_registration"] is True + assert "authentication" not in response + + +def test_to_api_response_with_sensitive_data_includes_masked_values() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={ + "credentials": {"client_information": {"is_dynamic_registration": False}}, + "icon": {"en_US": "icon.png"}, + } + ) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + with 
patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}): + with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}): + # Act + response = entity.to_api_response(user_name="Rajat", include_sensitive=True) + + # Assert + assert response["author"] == "Rajat" + assert response["masked_headers"] == {"Authorization": "Be****"} + assert response["authentication"] == {"client_id": "cl****"} + assert response["is_dynamic_registration"] is False + + +def test_retrieve_client_information_decrypts_nested_secret() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt: + # Act + client_info = entity.retrieve_client_information() + + # Assert + assert client_info is not None + assert client_info.client_id == "client-1" + assert client_info.client_secret == "plain-secret" + mock_decrypt.assert_called_once_with("tenant-1", "enc-secret") + + +def test_retrieve_client_information_returns_none_for_missing_data() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + result_empty = entity.retrieve_client_information() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + result_invalid = entity.retrieve_client_information() + + # Assert + assert result_empty is None + assert result_invalid is None + + +def test_masked_server_url_hides_path_segments() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object( + MCPProviderEntity, + "decrypt_server_url", + return_value="https://api.example.com/v1/mcp?query=1", + ): + # Act + masked_url = entity.masked_server_url() + + # Assert + assert masked_url == "https://api.example.com/******?query=1" + + +def test_mask_value_covers_short_and_long_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + short_masked = entity._mask_value("short") + long_masked = entity._mask_value("abcdefghijkl") + + # Assert + assert short_masked == "*****" + assert long_masked == "ab********kl" + + +def test_masked_headers_masks_all_decrypted_header_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}): + # Act + masked = entity.masked_headers() + + # Assert + assert masked == {"Authorization": "ab****gh"} + + +def test_masked_credentials_handles_nested_secret_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "client_information": { + "client_id": "client-id", + "encrypted_client_secret": "encrypted-value", + "client_secret": "plain-secret", + } + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"): + # Act + masked = entity.masked_credentials() + + # Assert + assert masked["client_id"] == "cl*****id" + assert masked["client_secret"] == "pl********et" + + +def test_masked_credentials_returns_empty_for_missing_client_information() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + 
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + masked_empty = entity.masked_credentials() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + masked_invalid = entity.masked_credentials() + + # Assert + assert masked_empty == {} + assert masked_invalid == {} + + +def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object( + MCPProviderEntity, + "decrypt_credentials", + return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"}, + ): + # Act + tokens = entity.retrieve_tokens() + + # Assert + assert isinstance(tokens, OAuthTokens) + assert tokens.access_token == "token" + assert tokens.token_type == DEFAULT_TOKEN_TYPE + assert tokens.expires_in == DEFAULT_EXPIRES_IN + assert tokens.refresh_token == "refresh" + + +def test_retrieve_tokens_returns_none_when_access_token_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt: + # Act + tokens = entity.retrieve_tokens() + + # Assert + mock_decrypt.assert_called_once() + assert tokens is None + + +def test_decrypt_server_url_delegates_to_encrypter() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock: + # Act + decrypted = entity.decrypt_server_url() + + # Assert + mock.assert_called_once_with("tenant-1", "encrypted-server-url") + assert decrypted == "https://api.example.com" + + +def test_decrypt_authentication_injects_authorization_for_oauth() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}}) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc123", token_type="bearer"), + ): + # Act + headers = entity.decrypt_authentication() + + # Assert + assert headers["Authorization"] == "Bearer abc123" + + +def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={"authed": True, "headers": {"Authorization": "encrypted-header"}} + ) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc", token_type="bearer"), + ) as mock_tokens: + # Act + headers = entity.decrypt_authentication() + + # Assert + mock_tokens.assert_not_called() + assert headers == {"Authorization": "existing"} + + +def test_decrypt_dict_returns_empty_for_empty_input() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + decrypted = entity._decrypt_dict({}) + + # Assert + assert decrypted == {} + + +def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""} + + # Act + result = entity._decrypt_dict(input_data) + + # Assert + assert result is input_data + + +def 
test_decrypt_dict_only_decrypts_top_level_string_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + decryptor = Mock() + decryptor.decrypt.return_value = {"api_key": "plain-key"} + + def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache): + assert tenant_id == "tenant-1" + assert any(item.name == "api_key" for item in config) + return decryptor, None + + with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter): + # Act + result = entity._decrypt_dict( + { + "api_key": "encrypted-key", + "nested": {"client_id": "unchanged"}, + "empty": "", + "count": 2, + } + ) + + # Assert + decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"}) + assert result["api_key"] == "plain-key" + assert result["nested"] == {"client_id": "unchanged"} + assert result["count"] == 2 + + +def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock: + # Act + headers = entity.decrypt_headers() + credentials = entity.decrypt_credentials() + + # Assert + assert mock.call_count == 2 + assert headers == {"h": "v"} + assert credentials == {"c": "v"} diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py new file mode 100644 index 0000000000..7a3d5e84ed --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -0,0 +1,92 @@ +"""Unit tests for model entity behavior and invariants. + +Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus, +ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n +labels are provided via I18nObject, that model metadata uses the FetchFrom +and ModelType enums, and that SimpleModelProviderEntity maps its fields from +a ProviderEntity configured with ConfigurateMethod.PREDEFINED_MODEL. 
+""" + +import pytest + +from core.entities.model_entities import ( + DefaultModelEntity, + DefaultModelProviderEntity, + ModelStatus, + ProviderModelWithStatusEntity, + SimpleModelProviderEntity, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: + return ProviderModelWithStatusEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=status, + ) + + +def test_simple_model_provider_entity_maps_from_provider_entity() -> None: + # Arrange + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + # Act + simple_provider = SimpleModelProviderEntity(provider_entity) + + # Assert + assert simple_provider.provider == "openai" + assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.supported_model_types == [ModelType.LLM] + + +def test_provider_model_with_status_raises_for_known_error_statuses() -> None: + # Arrange + expectations = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + for status, message in expectations.items(): + # Act / Assert + with pytest.raises(ValueError, match=message): + _build_model_with_status(status).raise_for_status() + + +def test_provider_model_with_status_allows_active_and_credential_removed() -> None: + # Arrange + active_model = _build_model_with_status(ModelStatus.ACTIVE) + removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED) + + # Act / Assert + active_model.raise_for_status() + removed_model.raise_for_status() + + +def test_default_model_entity_accepts_model_field_name() -> None: + # Arrange / Act + default_model = DefaultModelEntity( + model="gpt-4o-mini", + model_type=ModelType.LLM, + provider=DefaultModelProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + ), + ) + + # Assert + assert default_model.model == "gpt-4o-mini" + assert default_model.provider.provider == "openai" diff --git a/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py new file mode 100644 index 0000000000..20b7bf2a9f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py @@ -0,0 +1,22 @@ +from core.entities.parameter_entities import ( + AppSelectorScope, + CommonParameterType, + ModelSelectorScope, + ToolSelectorScope, +) + + +def test_common_parameter_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert CommonParameterType.SECRET_INPUT.value == "secret-input" + assert CommonParameterType.MODEL_SELECTOR.value == "model-selector" + assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select" + assert CommonParameterType.ARRAY.value == "array" + assert CommonParameterType.OBJECT.value == "object" + + +def test_selector_scope_values_are_stable() -> None: + # Arrange / Act / Assert + assert AppSelectorScope.WORKFLOW.value == 
"workflow" + assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding" + assert ToolSelectorScope.BUILTIN.value == "builtin" diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py new file mode 100644 index 0000000000..82f98d07a3 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -0,0 +1,1850 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from constants import HIDDEN_VALUE +from core.entities.model_entities import ModelStatus +from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, + ModelLoadBalancingConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, + SystemConfigurationStatus, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from models.provider import ProviderType +from models.provider_ids import ModelProviderID + +_UNSET = object() + + +def _build_provider_configuration(*, provider_name: str = "openai") -> ProviderConfiguration: + provider_entity = ProviderEntity( + provider=provider_name, + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1_000, + quota_used=0, + is_valid=True, + restrict_models=[], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="tenant-1", + provider=provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + +def _build_ai_model(name: str, *, model_type: ModelType = ModelType.LLM) -> AIModelEntity: + return AIModelEntity( + model=name, + label=I18nObject(en_US=name), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _exec_result( + *, + scalar_one_or_none: Any = _UNSET, + scalar: Any = _UNSET, + scalars_all: Any = _UNSET, + scalars_first: Any = _UNSET, +) -> Mock: + result = Mock() + if scalar_one_or_none is not _UNSET: + result.scalar_one_or_none.return_value = scalar_one_or_none + if scalar is not _UNSET: + result.scalar.return_value = scalar + if scalars_all is not _UNSET or scalars_first is not _UNSET: + scalars = Mock() + if scalars_all is not _UNSET: + scalars.all.return_value = scalars_all + if scalars_first is not _UNSET: + 
scalars.first.return_value = scalars_first + result.scalars.return_value = scalars + return result + + +@contextmanager +def _patched_session(session: Mock): + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + mock_session_cls.return_value.__enter__.return_value = session + yield mock_session_cls + + +def _build_secret_provider_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + + +def _build_secret_model_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ], + ) + + +def test_extract_secret_variables_returns_only_secret_inputs() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + secret_variables = configuration.extract_secret_variables(credential_form_schemas) + assert secret_variables == ["api_key"] + + +def test_obfuscated_credentials_masks_only_secret_fields() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + with patch( + "core.entities.provider_configuration.encrypter.obfuscated_token", + side_effect=lambda value: f"masked-{value[-2:]}", + ): + obfuscated = configuration.obfuscated_credentials( + credentials={"api_key": "sk-test-1234", "endpoint": "https://api.example.com"}, + credential_form_schemas=credential_form_schemas, + ) + + assert obfuscated["api_key"] == "masked-34" + assert obfuscated["endpoint"] == "https://api.example.com" + + +def test_provider_configurations_behave_like_keyed_container() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + + configurations[provider_key] = configuration + + assert "openai" in configurations + assert configurations["openai"] is configuration + assert configurations.get("openai") is configuration + assert configurations.to_list() == [configuration] + assert list(configurations) == [(provider_key, configuration)] + + +def test_provider_configurations_get_models_forwards_filters() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + expected_model = Mock() + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[expected_model]) as mock_get: + models = configurations.get_models(provider="openai", model_type=ModelType.LLM, only_active=True) + + mock_get.assert_called_once_with(ModelType.LLM, True) + assert models == [expected_model] + + +def 
test_provider_configurations_get_models_skips_non_matching_provider_filter() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[Mock()]) as mock_get: + models = configurations.get_models(provider="anthropic", model_type=ModelType.LLM, only_active=True) + + assert models == [] + mock_get.assert_not_called() + + +def test_get_current_credentials_custom_provider_checks_current_credential() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + current_credential_id="credential-1", + current_credential_name="Primary", + available_credentials=[], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert mock_check.call_count == 1 + assert mock_check.call_args.kwargs["credential_id"] == "credential-1" + assert mock_check.call_args.kwargs["provider"] == "openai" + + +def test_get_current_credentials_custom_provider_checks_all_available_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[ + CredentialConfiguration(credential_id="cred-1", credential_name="First"), + CredentialConfiguration(credential_id="cred-2", credential_name="Second"), + ], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert [c.kwargs["credential_id"] for c in mock_check.call_args_list] == ["cred-1", "cred-2"] + assert all(c.kwargs["provider"] == "openai" for c in mock_check.call_args_list) + + +def test_get_system_configuration_status_returns_none_when_current_quota_missing() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.FREE + + status = configuration.get_system_configuration_status() + assert status is None + + +def test_get_provider_names_supports_legacy_and_full_plugin_id() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider = "langgenius/openai/openai" + + provider_names = configuration._get_provider_names() + assert provider_names == ["langgenius/openai/openai", "openai"] + + +def test_generate_next_api_key_name_uses_highest_numeric_suffix() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [ + SimpleNamespace(credential_name="API KEY 9"), + SimpleNamespace(credential_name="legacy"), + SimpleNamespace(credential_name=" API KEY 2 "), + ] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 10" + + +def test_generate_next_api_key_name_falls_back_to_default_on_error() -> None: + configuration = 
_build_provider_configuration() + session = Mock() + + def _raise_query_error(): + raise RuntimeError("boom") + + name = configuration._generate_next_api_key_name(session=session, query_factory=_raise_query_error) + assert name == "API KEY 1" + + +def test_generate_provider_and_custom_model_names_delegate_to_shared_generator() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_generate_next_api_key_name", return_value="API KEY 7") as mock_generator: + provider_name = configuration._generate_provider_credential_name(session=Mock()) + custom_model_name = configuration._generate_custom_model_credential_name( + model="gpt-4o", + model_type=ModelType.LLM, + session=Mock(), + ) + + assert provider_name == "API KEY 7" + assert custom_model_name == "API KEY 7" + assert mock_generator.call_count == 2 + + +def test_get_provider_credential_uses_specific_lookup_when_id_provided() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_get_specific_provider_credential", return_value={"api_key": "***"}) as mock_get: + credential = configuration.get_provider_credential("credential-1") + + assert credential == {"api_key": "***"} + mock_get.assert_called_once_with("credential-1") + + +def test_validate_provider_credentials_handles_hidden_secret_value() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + session=session, + ) + + assert validated["openai_api_key"] == "enc::restored-key" + assert validated["region"] == "us" + mock_factory.provider_credentials_validate.assert_called_once_with( + provider="openai", + credentials={"openai_api_key": "restored-key", "region": "us"}, + ) + + +def test_validate_provider_credentials_opens_session_when_not_passed() -> None: + configuration = _build_provider_configuration() + mock_session = Mock() + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"region": "us"} + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = mock_session + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + + assert validated == {"region": "us"} + mock_session_cls.assert_called_once() + + +def 
test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: + configuration = _build_provider_configuration() + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + configuration.preferred_provider_type = ProviderType.CUSTOM + configuration.system_configuration.enabled = False + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + +def test_switch_preferred_provider_type_updates_existing_record_with_session() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.CUSTOM + session = Mock() + existing_record = SimpleNamespace(preferred_provider_type="custom") + session.execute.return_value.scalars.return_value.first.return_value = existing_record + + configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) + + assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value + session.commit.assert_called_once() + + +def test_switch_preferred_provider_type_creates_record_when_missing() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.SYSTEM + session = Mock() + session.execute.return_value.scalars.return_value.first.return_value = None + + configuration.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + + assert session.add.call_count == 1 + session.commit.assert_called_once() + + +def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: + configuration = _build_provider_configuration() + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + mock_factory.get_model_schema.assert_called_once_with( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"api_key": "x"}, + ) + + +def test_get_provider_model_returns_none_when_model_not_found() -> None: + configuration = _build_provider_configuration() + fake_model = SimpleNamespace(model="other-model") + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[fake_model]): + selected = configuration.get_provider_model(ModelType.LLM, "gpt-4o") + + assert selected is None + + +def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> None: + configuration = _build_provider_configuration() + configuration.provider.position = {"llm": ["b-model", "a-model"]} + configuration.model_settings = [ + ModelSettings(model="a-model", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + 
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("a-model"), _build_ai_model("b-model"), _build_ai_model("a-model")], + ) + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = provider_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) + active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) + + assert [model.model for model in all_models] == ["b-model", "a-model"] + assert [model.status for model in all_models] == [ModelStatus.ACTIVE, ModelStatus.DISABLED] + assert [model.model for model in active_models] == ["b-model"] + + +def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="custom-model", + model_type=ModelType.LLM, + credentials=None, + available_model_credentials=[CredentialConfiguration(credential_id="c-1", credential_name="first")], + ) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-model")], + ) + model_setting_map = { + ModelType.LLM: { + "base-model": ModelSettings( + model="base-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-base", + name="LB Base", + credentials={}, + credential_source_type="provider", + ) + ], + ), + "custom-model": ModelSettings( + model="custom-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-custom", + name="LB Custom", + credentials={}, + credential_source_type="custom_model", + ) + ], + ), + } + } + + with patch.object(ProviderConfiguration, "get_model_schema", return_value=_build_ai_model("custom-model")): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + ) + + status_map = {model.model: model.status for model in models} + invalid_lb_map = {model.model: model.has_invalid_load_balancing_configs for model in models} + assert status_map["base-model"] == ModelStatus.ACTIVE + assert status_map["custom-model"] == ModelStatus.CREDENTIAL_REMOVED + assert invalid_lb_map["base-model"] is True + assert invalid_lb_map["custom-model"] is True + + +def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None: + provider = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + 
RestrictModel(model="restricted", base_model_name="base-model", model_type=ModelType.LLM) + ], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + configuration = ProviderConfiguration( + tenant_id="tenant-1", + provider=provider, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + assert ConfigurateMethod.PREDEFINED_MODEL in configuration.provider.configurate_methods + + +def test_get_current_credentials_system_handles_disable_and_restricted_base_model() -> None: + configuration = _build_provider_configuration() + configuration.model_settings = [ + ModelSettings(model="gpt-4o", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + + with pytest.raises(ValueError, match="Model gpt-4o is disabled"): + configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + configuration.model_settings = [] + configuration.system_configuration.quota_configurations[0].restrict_models = [ + RestrictModel(model="gpt-4o", base_model_name="base-model", model_type=ModelType.LLM) + ] + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "base-model" + + +def test_get_current_credentials_prefers_model_specific_custom_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"api_key": "model-key"}, + ) + ] + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials == {"api_key": "model-key"} + + +def test_get_system_configuration_status_falsey_quota_returns_unsupported() -> None: + class _FalseyQuota: + quota_type = ProviderQuotaType.TRIAL + is_valid = True + + def __bool__(self) -> bool: + return False + + configuration = _build_provider_configuration() + configuration.system_configuration.quota_configurations = [_FalseyQuota()] # type: ignore[list-item] + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + +def test_get_provider_credential_default_uses_custom_provider_credentials() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + obfuscated = configuration.get_provider_credential() + assert obfuscated == {"api_key": "provider-key"} + + +def test_custom_configuration_availability_and_provider_record_helpers() -> None: + configuration = _build_provider_configuration() + assert not configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="Main")], + ) + assert configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = None + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="gpt-4o", model_type=ModelType.LLM, credentials={"api_key": "model-key"}) + ] + assert 
configuration.is_custom_configuration_available() + + session = Mock() + provider_record = SimpleNamespace(id="provider-1") + session.execute.return_value.scalar_one_or_none.return_value = provider_record + assert configuration._get_provider_record(session) is provider_record + + session.execute.return_value.scalar_one_or_none.return_value = None + assert configuration._get_provider_record(session) is None + + +def test_check_provider_credential_name_exists_and_model_setting_lookup() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = "existing-id" + assert configuration._check_provider_credential_name_exists("Main", session) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_provider_credential_name_exists("Main", session, exclude_id="cred-2") + + setting = SimpleNamespace(id="setting-1") + session.execute.return_value.scalars.return_value.first.return_value = setting + assert configuration._get_provider_model_setting(ModelType.LLM, "gpt-4o", session) is setting + + +def test_validate_provider_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-key"} + + +def test_generate_next_api_key_name_returns_default_when_no_records() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 1" + + +def test_create_provider_credential_creates_provider_record_when_missing() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch.object( + ProviderConfiguration, + "_generate_provider_credential_name", + return_value="API KEY 2", + ): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_provider_credential({"api_key": "raw"}, None) + + assert session.add.call_count == 2 + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.CUSTOM, session=session) + + +def test_create_provider_credential_marks_existing_provider_as_valid() -> None: + configuration = _build_provider_configuration() + session = Mock() + 
provider_record = SimpleNamespace(is_valid=False) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + session.commit.assert_called_once() + + +def test_create_provider_credential_raises_when_duplicate_name_exists() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + +def test_update_provider_credential_success_updates_and_invalidates_cache() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_provider_credential( + credentials={"api_key": "raw"}, + credential_id="cred-1", + credential_name="New Name", + ) + + assert credential_record.credential_name == "New Name" + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_lb.assert_called_once() + + +def test_update_provider_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", None) + + +def test_update_load_balancing_configs_updates_all_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + lb_config = SimpleNamespace(id="lb-1", encrypted_config="old", name="old", updated_at=None) + session.execute.return_value.scalars.return_value.all.return_value = [lb_config] + credential_record = SimpleNamespace(encrypted_config='{"api_key":"enc"}', credential_name="API KEY 3") + + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=credential_record, + 
credential_source="provider", + session=session, + ) + + assert lb_config.encrypted_config == '{"api_key":"enc"}' + assert lb_config.name == "API KEY 3" + mock_cache.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +def test_update_load_balancing_configs_returns_when_no_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), + credential_source="provider", + session=session, + ) + + session.commit.assert_not_called() + + +def test_delete_provider_credential_removes_provider_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert any(call.args and call.args[0] is provider_record for call in session.delete.call_args_list) + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_delete_provider_credential_raises_when_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_provider_credential("cred-1") + + +def test_delete_provider_credential_unsets_active_credential_when_more_available() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert provider_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_switch_active_provider_credential_success_and_failures() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with 
pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Provider record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_active_provider_credential("cred-1") + + assert provider_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(ProviderType.CUSTOM, session=session) + + +def test_get_custom_model_record_supports_plugin_id_alias() -> None: + configuration = _build_provider_configuration(provider_name="langgenius/openai/openai") + session = Mock() + custom_model_record = SimpleNamespace(id="model-1") + session.execute.return_value.scalar_one_or_none.return_value = custom_model_record + + result = configuration._get_custom_model_record(ModelType.LLM, "gpt-4o", session) + assert result is custom_model_record + + +def test_get_specific_custom_model_credential_success_and_not_found() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + record = SimpleNamespace(id="cred-1", credential_name="Main", encrypted_config='{"openai_api_key":"enc"}') + session.execute.return_value.scalar_one_or_none.return_value = record + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + response = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert response["current_credential_id"] == "cred-1" + assert response["credentials"] == {"openai_api_key": "***"} + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential with id cred-1 not found"): + configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config="{invalid-json", + ) + with _patched_session(session): + invalid_json = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert invalid_json["credentials"] == {} + + +def test_check_custom_model_credential_name_exists_respects_exclusion() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + assert 
configuration._check_custom_model_credential_name_exists( + ModelType.LLM, "gpt-4o", "Main", session, exclude_id="other-id" + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_custom_model_credential_name_exists(ModelType.LLM, "gpt-4o", "Main", session) + + +def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback() -> None: + configuration = _build_provider_configuration() + with patch.object( + ProviderConfiguration, + "_get_specific_custom_model_credential", + return_value={"current_credential_id": "cred-1"}, + ) as mock_specific: + result = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert result == {"current_credential_id": "cred-1"} + mock_specific.assert_called_once() + + configuration.provider.model_credential_schema = _build_secret_model_schema() + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"openai_api_key": "raw"}, + current_credential_id="cred-1", + current_credential_name="Main", + ) + ] + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + fallback = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) + assert fallback == { + "current_credential_id": "cred-1", + "current_credential_name": "Main", + "credentials": {"openai_api_key": "***"}, + } + + configuration.custom_configuration.models = [] + assert configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) is None + + +def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc"}' + ) + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + assert validated == {"openai_api_key": "enc-new"} + + session = Mock() + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"region": "us"} + with _patched_session(session): + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) + assert validated == {"region": "us"} + + +def test_create_update_delete_custom_model_credential_flow() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + + with _patched_session(session): + with 
patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 1"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + assert session.add.call_count == 2 + assert mock_cache.return_value.delete.call_count == 1 + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc2"}, + ): + with patch.object( + ProviderConfiguration, + "_get_custom_model_record", + return_value=provider_model_record, + ): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="New Name", + credential_id="cred-1", + ) + assert credential_record.credential_name == "New Name" + assert mock_cache.return_value.delete.call_count == 1 + mock_lb.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + + +def test_add_model_credential_to_model_and_switch_custom_model_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + session.add.assert_called_once() + session.commit.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with 
_patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with pytest.raises(ValueError, match="Can't add same credential"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-2") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-2") + assert provider_model_record.credential_id == "cred-2" + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="custom model record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + + +def test_delete_custom_model_and_model_setting_methods() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_model_record = SimpleNamespace(id="model-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model(ModelType.LLM, "gpt-4o") + session.delete.assert_called_once_with(provider_model_record) + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + existing = SimpleNamespace(enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.enable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is True + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model(ModelType.LLM, "gpt-4o") + assert 
created.enabled is True + + session = Mock() + existing = SimpleNamespace(enabled=True, load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.disable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model(ModelType.LLM, "gpt-4o") + assert created.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.get_provider_model_setting(ModelType.LLM, "gpt-4o") + assert result is existing + + +def test_model_load_balancing_enable_disable_and_switch_preferred_provider_type_without_session() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar.return_value = 1 + with _patched_session(session): + with pytest.raises(ValueError, match="must be more than 1"): + configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + existing = SimpleNamespace(load_balancing_enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is True + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is True + + session = Mock() + existing = SimpleNamespace(load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is False + + configuration.preferred_provider_type = ProviderType.SYSTEM + switch_session = Mock() + with _patched_session(switch_session): + switch_session.execute.return_value.scalars.return_value.first.return_value = None + configuration.switch_preferred_provider_type(ProviderType.CUSTOM) + assert any( + call.args and call.args[0].__class__.__name__ == "TenantPreferredModelProvider" + for call in switch_session.add.call_args_list + ) + switch_session.commit.assert_called() + + +def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + 
models=[_build_ai_model("llm-model")], + ) + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="error-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="none-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel( + model="embed-model", + base_model_name="base", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], + ), + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + + def _system_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-model": + raise RuntimeError("boom") + if model == "none-model": + return None + if model == "embed-model": + return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING) + return _build_ai_model("target") + + with patch( + "core.entities.provider_configuration.original_provider_configurate_methods", + {"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]}, + ): + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema): + system_models = configuration._get_system_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={ + ModelType.LLM: { + "target": ModelSettings( + model="target", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + ) + } + }, + ) + assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models) + + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="skip-model-type", + model_type=ModelType.TEXT_EMBEDDING, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="skip-unadded", + model_type=ModelType.LLM, + credentials={"k": "v"}, + unadded_to_model_list=True, + ), + CustomModelConfiguration( + model="skip-filter", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="error-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="none-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="disabled-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + ] + + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-disabled")], + ) + model_setting_map = { + ModelType.LLM: { + "base-disabled": ModelSettings( + model="base-disabled", + model_type=ModelType.LLM, + enabled=False, + load_balancing_enabled=True, + load_balancing_configs=[ModelLoadBalancingConfiguration(id="lb-1", name="lb", credentials={})], + ), + "disabled-custom": ModelSettings( + model="disabled-custom", + model_type=ModelType.LLM, + enabled=False, + 
load_balancing_enabled=False, + load_balancing_configs=[], + ), + } + } + + def _custom_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + return _build_ai_model(model) + + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_custom_schema): + custom_models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model="disabled-custom", + ) + assert any(model.model == "base-disabled" and model.status == ModelStatus.DISABLED for model in custom_models) + assert any(model.model == "disabled-custom" and model.status == ModelStatus.DISABLED for model in custom_models) + + +def test_get_current_credentials_skips_non_current_quota_restrictions() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="free-base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="trial-base", model_type=ModelType.LLM), + ], + ), + ] + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "trial-base" + + +def test_get_system_configuration_status_covers_disabled_and_quota_exceeded() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.enabled = False + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + configuration.system_configuration.enabled = True + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=100, + is_valid=False, + restrict_models=[], + ) + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.QUOTA_EXCEEDED + + +def test_get_specific_provider_credential_decrypts_and_obfuscates_credentials() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret","region":"us"}' + ) + provider_record = SimpleNamespace(provider_name="aliased-openai") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw-secret"): + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "raw-secret", "region": "us"} + + +def 
test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret"}' + ) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch( + "core.entities.provider_configuration.encrypter.decrypt_token", + side_effect=RuntimeError("boom"), + ): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + +def test_create_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_provider_credential_name", return_value="API KEY 9"): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_provider_credential({"api_key": "raw"}, None) + + session.rollback.assert_called_once() + + +def test_update_provider_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + +def test_update_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, 
"validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + session.rollback.assert_called_once() + + +def test_delete_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_switch_active_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + session.commit.side_effect = RuntimeError("boom") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with pytest.raises(RuntimeError, match="boom"): + configuration.switch_active_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config='{"openai_api_key":"enc-secret"}', + ) + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert result["credentials"] == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + 
+def test_create_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, "Main") + + +def test_create_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 4"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + + session.rollback.assert_called_once() + + +def test_update_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + session.rollback.assert_called_once() + + +def 
test_delete_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + +def test_delete_custom_model_credential_removes_custom_model_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert any(call.args and call.args[0] is provider_model_record for call in session.delete.call_args_list) + + +def test_delete_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session.rollback.assert_called_once() + + +def test_get_custom_provider_models_skips_schema_models_with_mismatched_type() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[ + _build_ai_model("llm-model", model_type=ModelType.LLM), + _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING), + ], + ) + + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert any(model.model == "llm-model" for model in models) + assert all(model.model != "embed-model" for model in models) + + +def test_get_custom_provider_models_skips_custom_models_on_schema_error_or_none() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="error-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="none-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="ok-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[], + ) + + def _schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + 
return _build_ai_model(model) + + with patch("core.entities.provider_configuration.logger.warning") as mock_warning: + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_schema): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert mock_warning.call_count == 1 + assert any(model.model == "ok-custom" for model in models) + assert all(model.model != "none-custom" for model in models) diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py new file mode 100644 index 0000000000..c5bfd05a1e --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -0,0 +1,72 @@ +import pytest + +from core.entities.parameter_entities import AppSelectorScope +from core.entities.provider_entities import ( + BasicProviderConfig, + ModelSettings, + ProviderConfig, + ProviderQuotaType, +) +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType + + +def test_provider_quota_type_value_of_returns_enum_member() -> None: + # Arrange / Act + quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value) + + # Assert + assert quota_type == ProviderQuotaType.TRIAL + + +def test_provider_quota_type_value_of_rejects_unknown_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="No matching enum found"): + ProviderQuotaType.value_of("enterprise") + + +def test_basic_provider_config_type_value_of_handles_known_values() -> None: + # Arrange / Act + parameter_type = BasicProviderConfig.Type.value_of("text-input") + + # Assert + assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT + + +def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="invalid mode value"): + BasicProviderConfig.Type.value_of("unknown") + + +def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None: + # Arrange + provider_config = ProviderConfig( + type=BasicProviderConfig.Type.SELECT, + name="workspace", + scope=AppSelectorScope.ALL, + options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))], + ) + + # Act + basic_config = provider_config.to_basic_provider_config() + + # Assert + assert isinstance(basic_config, BasicProviderConfig) + assert basic_config.type == BasicProviderConfig.Type.SELECT + assert basic_config.name == "workspace" + + +def test_model_settings_accepts_model_field_name() -> None: + # Arrange / Act + settings = ModelSettings( + model="gpt-4o", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=False, + load_balancing_configs=[], + ) + + # Assert + assert settings.model == "gpt-4o" + assert settings.model_type == ModelType.LLM diff --git a/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py new file mode 100644 index 0000000000..399b531205 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py @@ -0,0 +1,137 @@ +import httpx +import pytest + +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor +from models.api_based_extension import APIBasedExtensionPoint + + +def test_request_success(mocker): + # Mock httpx.Client and its context manager + mock_client = 
mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"}) + + assert result == {"result": "success"} + mock_client_instance.request.assert_called_once_with( + method="POST", + url="http://example.com", + json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}}, + headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"}, + ) + + +def test_request_with_ssrf_proxy(mocker): + # Mock dify_config + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081") + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + # Mock HTTPTransport + mock_transport = mocker.patch("httpx.HTTPTransport") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert "mounts" in kwargs + assert "http://" in kwargs["mounts"] + assert "https://" in kwargs["mounts"] + assert mock_transport.call_count == 2 + + +def test_request_with_only_one_proxy_config(mocker): + # Mock dify_config with only one proxy + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None) + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts=None (default) + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert kwargs.get("mounts") is None + + +def test_request_timeout(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.TimeoutException("timeout") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request timeout"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_connection_error(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = 
mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.RequestError("error") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request connection error"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code_long_content(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + mock_response.text = "A" * 200 # Testing truncation of content + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + expected_content = "A" * 100 + with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"): + requestor.request(APIBasedExtensionPoint.PING, {}) diff --git a/api/tests/unit_tests/core/extension/test_extensible.py b/api/tests/unit_tests/core/extension/test_extensible.py new file mode 100644 index 0000000000..9bce0cd7c8 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extensible.py @@ -0,0 +1,281 @@ +import json +import types +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from core.extension.extensible import Extensible + + +class TestExtensible: + def test_init(self): + tenant_id = "tenant_123" + config = {"key": "value"} + ext = Extensible(tenant_id, config) + assert ext.tenant_id == tenant_id + assert ext.config == config + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.sort_to_dict_by_position_map") + def test_scan_extensions_success( + self, + mock_sort, + mock_module_from_spec, + mock_read_text, + mock_exists, + mock_isdir, + mock_listdir, + mock_dirname, + mock_find_spec, + ): + # Setup + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [ + ["ext1"], # package_dir + ["ext1.py", "__builtin__"], # subdir_path + ] + mock_isdir.return_value = True + + mock_exists.return_value = True + mock_read_text.return_value = "10" + + # Use types.ModuleType to avoid MagicMock __dict__ issues + mock_mod = types.ModuleType("ext1") + + 
class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + mock_sort.side_effect = lambda position_map, data, name_func: data + + # Execute + results = Extensible.scan_extensions() + + # Assert + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].position == 10 + assert results[0].builtin is True + assert results[0].extension_class == MockExtension + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_package_not_found(self, mock_find_spec): + mock_find_spec.return_value = None + with pytest.raises(ImportError, match="Could not find package"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + mock_find_spec.return_value = package_spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []] + + mock_isdir.side_effect = [False, True] + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_success( + self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]] + mock_isdir.return_value = True + + # exists checks: only schema.json needs to exist + mock_exists.return_value = True + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]}) + + with ( + patch("builtins.open", mock_open(read_data=schema_content)), + patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ), + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].builtin is False + assert results[0].label == {"en": "Test"} + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_missing_schema( + self, mock_module_from_spec, mock_exists, 
mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # exists: only schema.json is checked here, and it is missing + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.os.path.exists") + def test_scan_extensions_no_extension_class( + self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # Mock not builtin + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + mock_mod.SomeOtherClass = type("SomeOtherClass", (), {}) + mock_module_from_spec.return_value = mock_mod + + # The schema check is never reached: the module defines no Extensible subclass, so the extension is skipped + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + mock_find_spec.side_effect = [package_spec, None] # No module spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + with pytest.raises(ImportError, match="Failed to load module"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_general_exception(self, mock_find_spec): + mock_find_spec.side_effect = Exception("Unexpected error") + with pytest.raises(Exception, match="Unexpected error"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def 
test_scan_extensions_builtin_without_position_file( + self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]] + mock_isdir.return_value = True + + # builtin exists in listdir, but os.path.exists(builtin_file_path) returns False + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].position == 0 diff --git a/api/tests/unit_tests/core/extension/test_extension.py b/api/tests/unit_tests/core/extension/test_extension.py new file mode 100644 index 0000000000..4ad32d3840 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extension.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.extension.extensible import ExtensionModule, ModuleExtension +from core.extension.extension import Extension + + +class TestExtension: + def setup_method(self): + # Reset the private class attribute before each test + Extension._Extension__module_extensions = {} + + def test_init(self): + # Mock scan_extensions for Moderation and ExternalDataTool + mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")} + mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")} + + extension = Extension() + + # We need to mock scan_extensions on the classes defined in Extension.module_classes + with ( + patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions), + patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions), + ): + extension.init() + + # Check if internal state is updated + internal_state = Extension._Extension__module_extensions + assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions + assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions + + def test_module_extensions_success(self): + # Setup data + mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")} + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions} + + extension = Extension() + result = extension.module_extensions(ExtensionModule.MODERATION.value) + + assert len(result) == 2 + assert any(e.name == "name1" for e in result) + assert any(e.name == "name2" for e in result) + + def test_module_extensions_not_found(self): + extension = Extension() + with pytest.raises(ValueError, match="Extension Module unknown not found"): + extension.module_extensions("unknown") + + def test_module_extension_success(self): + mock_ext = ModuleExtension(name="test_ext") + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.module_extension(ExtensionModule.MODERATION, "test_ext") + assert result == mock_ext + + def test_module_extension_module_not_found(self): + 
extension = Extension() + # ExtensionModule.MODERATION is "moderation" + with pytest.raises(ValueError, match="Extension Module moderation not found"): + extension.module_extension(ExtensionModule.MODERATION, "any") + + def test_module_extension_extension_not_found(self): + # We need a non-empty dict because 'if not module_extensions' in extension.py + # returns True for an empty dict, which raises the module not found error instead. + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}} + + extension = Extension() + with pytest.raises(ValueError, match="Extension unknown not found"): + extension.module_extension(ExtensionModule.MODERATION, "unknown") + + def test_extension_class_success(self): + class MockClass: + pass + + mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.extension_class(ExtensionModule.MODERATION, "test_ext") + assert result == MockClass + + def test_extension_class_none(self): + mock_ext = ModuleExtension(name="test_ext", extension_class=None) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + with pytest.raises(AssertionError): + extension.extension_class(ExtensionModule.MODERATION, "test_ext") diff --git a/api/tests/unit_tests/core/external_data_tool/api/test_api.py b/api/tests/unit_tests/core/external_data_tool/api/test_api.py new file mode 100644 index 0000000000..1653124bd8 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/api/test_api.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.external_data_tool.api.api import ApiExternalDataTool +from models.api_based_extension import APIBasedExtensionPoint + + +def test_api_external_data_tool_name(): + assert ApiExternalDataTool.name == "api" + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_success(mock_db): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_db.session.scalar.return_value = mock_extension + + # Should not raise exception + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +def test_validate_config_missing_id(): + with pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiExternalDataTool.validate_config("tenant_id", {}) + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_invalid_id(mock_db): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match="api_based_extension_id is invalid"): + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +@pytest.fixture +def api_tool(): + # Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel + return ApiExternalDataTool( + tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"} + ) + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + 
mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": "success_result"} + + res = api_tool.query({"input1": "value1"}, "query_str") + + assert res == "success_result" + + mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key") + mock_requestor.request.assert_called_once_with( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"}, + ) + + +def test_query_missing_config(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1") + api_tool.config = None # Force None + with pytest.raises(ValueError, match="config is required"): + api_tool.query({}, "") + + +def test_query_missing_extension_id(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"}) + with pytest.raises(AssertionError, match="api_based_extension_id is required"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +def test_query_invalid_extension(mock_db, api_tool): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor_class.side_effect = Exception("init error") + + with pytest.raises(ValueError, match=".*error: init error"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"other": "value"} + + with pytest.raises(ValueError, match=".*error: result not found in response"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + 
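# query() is expected to decrypt the stored API key before building the requestor +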
mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": 123} # Not a string + + with pytest.raises(ValueError, match=".*error: result is not string"): + api_tool.query({}, "") diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py new file mode 100644 index 0000000000..216cda83c5 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -0,0 +1,66 @@ +import pytest + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.base import ExternalDataTool + + +class TestExternalDataTool: + def test_module_attribute(self): + assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL + + def test_init(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) + assert tool.tenant_id == "tenant_1" + assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config == {"key": "value"} + + def test_init_without_config(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + assert tool.tenant_id == "tenant_1" + assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config is None + + def test_validate_config_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + with pytest.raises(NotImplementedError): + ConcreteTool.validate_config("tenant_1", {}) + + def test_query_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + with pytest.raises(NotImplementedError): + tool.query({}) diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py new file mode 100644 index 0000000000..86b461cf04 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -0,0 +1,115 @@ +from unittest.mock import patch + +import pytest +from flask import Flask + +from core.app.app_config.entities import ExternalDataVariableEntity +from core.external_data_tool.external_data_fetch import ExternalDataFetch + + +class TestExternalDataFetch: + @pytest.fixture + def app(self): + app = Flask(__name__) + return app + + def test_fetch_success(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + + # Setup mocks + tool1 = ExternalDataVariableEntity(variable="var1", 
type="type1", config={"c1": "v1"}) + tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"}) + + external_data_tools = [tool1, tool2] + inputs = {"input_key": "input_value"} + query = "test query" + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + # Create distinct mock instances for each tool to ensure deterministic results + # This approach is robust regardless of thread scheduling order + from unittest.mock import MagicMock + + def factory_side_effect(*args, **kwargs): + variable = kwargs.get("variable") + mock_instance = MagicMock() + if variable == "var1": + mock_instance.query.return_value = "result1" + elif variable == "var2": + mock_instance.query.return_value = "result2" + return mock_instance + + MockFactory.side_effect = factory_side_effect + + result_inputs = fetcher.fetch( + tenant_id="tenant1", + app_id="app1", + external_data_tools=external_data_tools, + inputs=inputs, + query=query, + ) + + # Each tool gets its deterministic result regardless of thread completion order + assert result_inputs["var1"] == "result1" + assert result_inputs["var2"] == "result2" + assert result_inputs["input_key"] == "input_value" + assert len(result_inputs) == 3 + + # Verify factory calls + assert MockFactory.call_count == 2 + MockFactory.assert_any_call( + name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"} + ) + MockFactory.assert_any_call( + name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"} + ) + + def test_fetch_no_tools(self): + # We don't necessarily need app_context if there are no tools, + # but fetch calls current_app._get_current_object() only inside the loop. + # Wait, let's look at the code. + # for tool in external_data_tools: + # executor.submit(..., current_app._get_current_object(), ...) + # So if external_data_tools is empty, it shouldn't access current_app. 
+ fetcher = ExternalDataFetch() + inputs = {"input_key": "input_value"} + result_inputs = fetcher.fetch( + tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query" + ) + assert result_inputs == inputs + assert result_inputs is not inputs # Should be a copy + + def test_fetch_with_none_variable(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) + + # Patch _query_external_data_tool to return None variable + with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query: + mock_query.return_value = (None, "some_result") + + result_inputs = fetcher.fetch( + tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q" + ) + + assert "var1" not in result_inputs + assert result_inputs == {"in": "val"} + + def test_query_external_data_tool(self, app): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"}) + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + mock_factory_instance = MockFactory.return_value + mock_factory_instance.query.return_value = "query_result" + + var, res = fetcher._query_external_data_tool( + flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q" + ) + + assert var == "var1" + assert res == "query_result" + MockFactory.assert_called_once_with( + name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"} + ) + mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q") diff --git a/api/tests/unit_tests/core/external_data_tool/test_factory.py b/api/tests/unit_tests/core/external_data_tool/test_factory.py new file mode 100644 index 0000000000..6bb384b0ac --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_factory.py @@ -0,0 +1,58 @@ +from unittest.mock import MagicMock, patch + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.factory import ExternalDataToolFactory + + +def test_external_data_tool_factory_init(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + app_id = "app_456" + variable = "var_v" + config = {"key": "value"} + + factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.assert_called_once_with( + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config + ) + + +def test_external_data_tool_factory_validate_config(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + config = {"key": "value"} + + ExternalDataToolFactory.validate_config(name, tenant_id, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.validate_config.assert_called_once_with(tenant_id, config) + + +def test_external_data_tool_factory_query(): + with 
patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_extension_instance = MagicMock() + mock_extension_class.return_value = mock_extension_instance + mock_code_based_extension.extension_class.return_value = mock_extension_class + + mock_extension_instance.query.return_value = "query_result" + + factory = ExternalDataToolFactory("name", "tenant", "app", "var", {}) + + inputs = {"input_key": "input_value"} + query = "search_query" + + result = factory.query(inputs, query) + + assert result == "query_result" + mock_extension_instance.query.assert_called_once_with(inputs, query) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py new file mode 100644 index 0000000000..b2783bdf99 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py @@ -0,0 +1,103 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.prompts import ( + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, +) + + +class TestRuleConfigGeneratorOutputParser: + def test_get_format_instructions(self): + parser = RuleConfigGeneratorOutputParser() + instructions = parser.get_format_instructions() + assert instructions == ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) + + def test_parse_success(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + result = parser.parse(text) + assert result["prompt"] == "This is a prompt" + assert result["variables"] == ["var1", "var2"] + assert result["opening_statement"] == "Hello!" + + def test_parse_invalid_json(self): + parser = RuleConfigGeneratorOutputParser() + text = "invalid json" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Parsing text" in str(excinfo.value) + assert "could not find json block in the output" in str(excinfo.value) + + def test_parse_missing_keys(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"] +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "expected key `opening_statement` to be present" in str(excinfo.value) + + def test_parse_wrong_type_prompt(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": 123, + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'prompt' to be a string" in str(excinfo.value) + + def test_parse_wrong_type_variables(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": "not a list", + "opening_statement": "Hello!" 
+} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'variables' to be a list" in str(excinfo.value) + + def test_parse_wrong_type_opening_statement(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": 123 +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'opening_statement' to be a str" in str(excinfo.value) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py new file mode 100644 index 0000000000..46c9dc6f9c --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -0,0 +1,402 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType + + +class TestStructuredOutput: + def test_remove_additional_properties(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "additionalProperties": False, + "nested": {"type": "object", "additionalProperties": True}, + "items": [{"type": "object", "additionalProperties": False}], + } + remove_additional_properties(schema) + assert "additionalProperties" not in schema + assert "additionalProperties" not in schema["nested"] + assert "additionalProperties" not in schema["items"][0] + + # Test with non-dict input + remove_additional_properties(None) # Should not raise + remove_additional_properties([]) # Should not raise + + def test_convert_boolean_to_string(self): + schema = { + "type": "object", + "properties": { + "is_active": {"type": "boolean"}, + "tags": {"type": "array", "items": {"type": "boolean"}}, + "list_schema": [{"type": "boolean"}], + }, + } + convert_boolean_to_string(schema) + assert schema["properties"]["is_active"]["type"] == "string" + assert schema["properties"]["tags"]["items"]["type"] == "string" + assert schema["properties"]["list_schema"][0]["type"] == "string" + + # Test with non-dict input + convert_boolean_to_string(None) # Should not raise + convert_boolean_to_string([]) # Should not raise + + def test_parse_structured_output_valid(self): + text = '{"key": "value"}' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_non_dict_valid_json(self): + # Even if it's valid JSON, if it's not a dict, it should try repair or fail + text = '["a", "b"]' + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = {"key": "value"} + assert 
_parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_not_dict_fail_via_validate(self): + # Force TypeAdapter to return a non-dict to trigger line 292 + with patch("pydantic.TypeAdapter.validate_json") as mock_validate: + mock_validate.return_value = ["a list"] + with pytest.raises(OutputParserError) as excinfo: + _parse_structured_output('["a list"]') + assert "Failed to parse structured output" in str(excinfo.value) + + def test_parse_structured_output_repair_success(self): + text = "{'key': 'value'}" # Invalid JSON (single quotes) + # json_repair should handle this + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list(self): + # Deepseek-r1 case: result is a list containing a dict + text = '[{"key": "value"}]' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list_no_dict(self): + # Deepseek-r1 case: result is a list with NO dict + text = "[1, 2, 3]" + assert _parse_structured_output(text) == {} + + def test_parse_structured_output_repair_fail(self): + text = "not a json at all" + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = "still not a dict or list" + with pytest.raises(OutputParserError): + _parse_structured_output(text) + + def test_set_response_format(self): + # Test JSON + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON + + # Test JSON_OBJECT + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_OBJECT], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON_OBJECT + + def test_handle_native_json_schema(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + + updated_params = _handle_native_json_schema( + provider, model_schema, structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"} + assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA + + def test_handle_native_json_schema_no_format_rule(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [] + + updated_params = _handle_native_json_schema( + provider, model_schema, structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert "response_format" not in updated_params + + def test_handle_prompt_based_schema_with_system_prompt(self): + prompt_messages = [ + SystemPromptMessage(content="Existing system prompt"), + UserPromptMessage(content="User question"), + ] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], 
SystemPromptMessage) + assert "Existing system prompt" in result[0].content + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_handle_prompt_based_schema_without_system_prompt(self): + prompt_messages = [UserPromptMessage(content="User question")] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_prepare_schema_for_model_gemini(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gemini-1.5-pro" + schema = {"type": "object", "additionalProperties": False} + + result = _prepare_schema_for_model("google", model_schema, schema) + assert "additionalProperties" not in result + + def test_prepare_schema_for_model_ollama(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "llama3" + schema = {"type": "object"} + + result = _prepare_schema_for_model("ollama", model_schema, schema) + assert result == schema + + def test_prepare_schema_for_model_default(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + schema = {"type": "object"} + + result = _prepare_schema_for_model("openai", model_schema, schema) + assert result == {"schema": schema, "name": "llm_response"} + + def test_invoke_llm_with_structured_output_no_stream_native(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = True + model_schema.parameter_rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + model_schema.model = "gpt-4o" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "gpt-4o" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_native" + mock_result.prompt_messages = [UserPromptMessage(content="hi")] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_native" + + def test_invoke_llm_with_structured_output_no_stream_prompt_based(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + model_schema.model = "claude-3" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "claude-3" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_prompt" + mock_result.prompt_messages = [] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="anthropic", + 
model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_prompt" + + def test_invoke_llm_with_structured_output_no_string_error(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")]) + + model_instance.invoke_llm.return_value = mock_result + + with pytest.raises(OutputParserError) as excinfo: + invoke_llm_with_structured_output( + provider="anthropic", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[], + json_schema={}, + stream=False, + ) + assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value) + + def test_invoke_llm_with_structured_output_stream(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + model_instance = MagicMock(spec=ModelInstance) + + # Mock chunks + chunk1 = MagicMock(spec=LLMResultChunk) + chunk1.delta = LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage() + ) + chunk1.prompt_messages = [UserPromptMessage(content="hi")] + chunk1.system_fingerprint = "fp1" + + chunk2 = MagicMock(spec=LLMResultChunk) + chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}')) + chunk2.prompt_messages = [UserPromptMessage(content="hi")] + chunk2.system_fingerprint = "fp1" + + chunk3 = MagicMock(spec=LLMResultChunk) + chunk3.delta = LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data=" "), + ] + ), + ) + chunk3.prompt_messages = [UserPromptMessage(content="hi")] + chunk3.system_fingerprint = "fp1" + + event4 = MagicMock() + event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="")) + + model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={}, + stream=True, + ) + + chunks = list(generator) + assert len(chunks) == 5 + assert chunks[-1].structured_output == {"key": "value"} + assert chunks[-1].system_fingerprint == "fp1" + assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")] + + def test_invoke_llm_with_structured_output_stream_no_id_events(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + model_instance = MagicMock(spec=ModelInstance) + model_instance.invoke_llm.return_value = [] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[], + json_schema={}, + stream=True, + ) + + with pytest.raises(OutputParserError): + list(generator) + + def test_parse_structured_output_empty_string(self): + with pytest.raises(OutputParserError): + 
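# an empty string leaves nothing for json_repair to recover, so parsing fails +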
_parse_structured_output("") diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py new file mode 100644 index 0000000000..5b7640696f --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -0,0 +1,589 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload +from core.llm_generator.llm_generator import LLMGenerator +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError + + +class TestLLMGenerator: + @pytest.fixture + def mock_model_instance(self): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + instance = MagicMock() + mock_manager.return_value.get_default_model_instance.return_value = instance + mock_manager.return_value.get_model_instance.return_value = instance + yield instance + + @pytest.fixture + def model_config_entity(self): + return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7}) + + def test_generate_conversation_name_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace: + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Test Conversation Name" + mock_trace.assert_called_once() + + def test_generate_conversation_name_truncated(self, mock_model_instance): + long_query = "a" * 2100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", long_query) + assert name == "Short Name" + + def test_generate_conversation_name_empty_answer(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "" + mock_model_instance.invoke_llm.return_value = mock_response + + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "" + + def test_generate_conversation_name_json_repair(self, mock_model_instance): + mock_response = MagicMock() + # Invalid JSON that json_repair can fix + mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}" + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Repaired Name" + + def test_generate_conversation_name_not_dict_result(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["not a dict"]' + mock_model_instance.invoke_llm.return_value = mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def 
test_generate_conversation_name_no_output_in_dict(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"something": "else"}' + mock_model_instance.invoke_llm.return_value = mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def test_generate_conversation_name_long_output(self, mock_model_instance): + long_output = "a" * 100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert len(name) == 78 # 75 + "..." + assert name.endswith("...") + + def test_generate_suggested_questions_after_answer_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]' + mock_model_instance.invoke_llm.return_value = mock_response + + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert len(questions) == 2 + assert questions[0] == "Question 1?" + + def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance): + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance): + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "Generated Prompt" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert result["prompt"] == "Generated Prompt" + assert result["error"] == "" + + def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate rule config" in result["error"] + + def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + 
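# a generic Exception (not InvokeError) should still be caught and reported +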
mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate rule config" in result["error"] + assert "Random error" in result["error"] + + def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + # Mocking 3 calls for invoke_llm + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + mock_res2 = MagicMock() + mock_res2.message.get_text_content.return_value = '"var1", "var2"' + + mock_res3 = MagicMock() + mock_res3.message.get_text_content.return_value = "Opening Statement" + + mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert result["prompt"] == "Step 1 Prompt" + assert result["variables"] == ["var1", "var2"] + assert result["opening_statement"] == "Opening Statement" + assert result["error"] == "" + + def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate prefix prompt" in result["error"] + + def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + # Step 2 fails + mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate variables" in result["error"] + + def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + mock_res2 = MagicMock() + mock_res2.message.get_text_content.return_value = '"var1"' + + # Step 3 fails + mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate conversation opener" in result["error"] + + def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + # Mock any step to throw Exception + mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to handle unexpected exception" in result["error"] + assert "Unexpected multi-step error" in result["error"] + + def test_generate_code_python_success(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload( + instruction="print hello", code_language="python", 
model_config=model_config_entity + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "print('hello')" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_code("tenant_id", payload) + assert result["code"] == "print('hello')" + assert result["language"] == "python" + + def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload( + instruction="console log hello", code_language="javascript", model_config=model_config_entity + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "console.log('hello')" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_code("tenant_id", payload) + assert result["code"] == "console.log('hello')" + assert result["language"] == "javascript" + + def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_code("tenant_id", payload) + assert "Failed to generate code" in result["error"] + + def test_generate_code_exception(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_code("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + + def test_generate_qa_document_success(self, mock_model_instance): + mock_response = MagicMock(spec=LLMResult) + mock_response.message = MagicMock() + mock_response.message.get_text_content.return_value = "QA Document Content" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_qa_document("tenant_id", "query", "English") + assert result == "QA Document Content" + + def test_generate_qa_document_type_error(self, mock_model_instance): + mock_model_instance.invoke_llm.return_value = "Not an LLMResult" + + with pytest.raises(TypeError, match="Expected LLMResult when stream=False"): + LLMGenerator.generate_qa_document("tenant_id", "query", "English") + + def test_generate_structured_output_success(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + parsed_output = json.loads(result["output"]) + assert parsed_output["type"] == "object" + assert result["error"] == "" + + def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "{'type': 'object'}" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + parsed_output = json.loads(result["output"]) + assert parsed_output["type"] == "object" + + def 
test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "true" # parsed as bool + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + assert "Failed to parse structured output" in result["error"] + + def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "Failed to generate JSON Schema" in result["error"] + + def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + + def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Mock __instruction_modify_common call via invoke_llm + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert result == {"modified": "prompt"} + + def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + last_run = MagicMock() + last_run.query = "q" + last_run.answer = "a" + last_run.error = "e" + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert result == {"modified": "prompt"} + + def test_instruction_modify_workflow_app_not_found(self): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="App not found."): + LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock()) + + def test_instruction_modify_workflow_no_workflow(self): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + with pytest.raises(ValueError, match="Workflow not found for the given app 
model."): + LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service) + + def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + + last_run = MagicMock() + last_run.node_type = "llm" + last_run.status = "s" + last_run.error = "e" + # Return regular values, not Mocks + last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]} + last_run.load_full_inputs.return_value = {"in": "val"} + + workflow_service.get_node_last_run.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "workflow"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "workflow"} + + def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "fallback"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "fallback"} + + def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + # Cause exception in node_type logic + workflow.graph_dict = {"graph": {"nodes": []}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "fallback"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "fallback"} + + def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + 
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + + last_run = MagicMock() + last_run.node_type = "llm" + last_run.status = "s" + last_run.error = "e" + # Return regular empty list, not a Mock + last_run.execution_metadata_dict = {"agent_log": []} + last_run.load_full_inputs.return_value = {} + + workflow_service.get_node_last_run.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "workflow"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "workflow"} + + def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): + # Testing placeholders replacement via instruction_modify_legacy for convenience + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"ok": true}' + mock_model_instance.invoke_llm.return_value = mock_response + + instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}" + LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current_val", instruction, model_config_entity, "ideal" + ) + + # Verify the call to invoke_llm contains replaced instruction + args, kwargs = mock_model_instance.invoke_llm.call_args + prompt_messages = kwargs["prompt_messages"] + user_msg = prompt_messages[1].content + user_msg_dict = json.loads(user_msg) + assert "null" in user_msg_dict["instruction"] # because last_run is None and current is current_val etc. 
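+ # {{#current#}} should have been replaced with the provided current prompt value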
+ assert "current_val" in user_msg_dict["instruction"] + + def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "No braces here" + mock_model_instance.invoke_llm.return_value = mock_response + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] + assert "Could not find a valid JSON object" in result["error"] + + def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "[1, 2, 3]" + mock_model_instance.invoke_llm.return_value = mock_response + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + # The exception message is "Expected a JSON object, but got list" + assert "An unexpected error occurred" in result["error"] + + def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + instance = MagicMock() + mock_manager.return_value.get_model_instance.return_value = instance + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"ok": true}' + instance.invoke_llm.return_value = mock_response + + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + + def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "Failed to generate code" in result["error"] + + def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] + + def 
test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "No JSON here" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index 7c2767266f..a8b186ac8a 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -82,6 +82,68 @@ class TestTraceContextFilter: assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" assert log_record.span_id == "051581bf3bb55c45" + def test_otel_context_invalid_trace_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + # Use mocks for base context to ensure we can test the fallback + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_invalid_span_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "" + + def test_otel_context_span_none(self, log_record): + from core.logging.filters import TraceContextFilter + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=None), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_exception(self, log_record): + from core.logging.filters import TraceContextFilter + + # Trigger exception in OTEL block + with ( + mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + class TestIdentityContextFilter: def test_sets_empty_identity_without_request_context(self, log_record): @@ -114,3 +176,119 @@ class TestIdentityContextFilter: result = filter.filter(log_record) assert result is True assert log_record.tenant_id == "" + + def test_sets_empty_identity_unauthenticated(self, log_record): + from 
core.logging.filters import IdentityContextFilter + + mock_user = mock.MagicMock() + mock_user.is_authenticated = False + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + assert log_record.user_id == "" + + def test_sets_identity_for_account(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = "tenant_id" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_account_no_tenant(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_end_user(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = "custom_type" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "custom_type" + + def test_sets_identity_for_end_user_default_type(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "end_user" diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 60f37b6de0..abf3c60fe0 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -1,27 +1,39 @@ """Unit tests for MCP OAuth authentication flow.""" 
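+# The imports below cover the new helpers under test: discovery-URL builders,
+# token-response parsing, effective-scope resolution, and the client-credentials flow.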
+import json
 from unittest.mock import Mock, patch
+import httpx
 import pytest
+from pydantic import ValidationError
 from core.entities.mcp_provider import MCPProviderEntity
+from core.helper import ssrf_proxy
 from core.mcp.auth.auth_flow import (
     OAUTH_STATE_EXPIRY_SECONDS,
     OAUTH_STATE_REDIS_KEY_PREFIX,
     OAuthCallbackState,
     _create_secure_redis_state,
+    _parse_token_response,
     _retrieve_redis_state,
     auth,
+    build_oauth_authorization_server_metadata_discovery_urls,
+    build_protected_resource_metadata_discovery_urls,
     check_support_resource_discovery,
+    client_credentials_flow,
+    discover_oauth_authorization_server_metadata,
     discover_oauth_metadata,
+    discover_protected_resource_metadata,
     exchange_authorization,
     generate_pkce_challenge,
+    get_effective_scope,
     handle_callback,
     refresh_authorization,
     register_client,
     start_authorization,
 )
 from core.mcp.entities import AuthActionType, AuthResult
+from core.mcp.error import MCPRefreshTokenError
 from core.mcp.types import (
     LATEST_PROTOCOL_VERSION,
     OAuthClientInformation,
@@ -764,3 +776,555 @@ class TestAuthOrchestration:
             auth(mock_provider, authorization_code="auth-code")
 
         assert "Existing OAuth client information is required" in str(exc_info.value)
+
+    def test_generate_pkce_challenge(self):
+        verifier, challenge = generate_pkce_challenge()
+        assert verifier
+        assert challenge
+        assert "=" not in verifier
+        assert "=" not in challenge
+
+    def test_build_protected_resource_metadata_discovery_urls(self):
+        # Case 1: WWW-Auth URL provided
+        urls = build_protected_resource_metadata_discovery_urls(
+            "https://auth.example.com/prm", "https://api.example.com"
+        )
+        assert "https://auth.example.com/prm" in urls
+        assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
+
+        # Case 2: No WWW-Auth URL, with path
+        urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1")
+        assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls
+        assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
+
+        # Case 3: No path
+        urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
+        assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
+
+    def test_build_oauth_authorization_server_metadata_discovery_urls(self):
+        # Case 1: with auth_server_url
+        urls = build_oauth_authorization_server_metadata_discovery_urls(
+            "https://auth.example.com", "https://api.example.com"
+        )
+        assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls
+        assert "https://auth.example.com/.well-known/openid-configuration" in urls
+
+        # Case 2: with path
+        urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant")
+        assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls
+        assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls
+
+    @patch("core.helper.ssrf_proxy.get")
+    def test_discover_protected_resource_metadata(self, mock_get):
+        # Success
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "resource": "https://api.example.com",
+            "authorization_servers": ["https://auth"],
+        }
+        mock_get.return_value = mock_response
+        result = discover_protected_resource_metadata(None, "https://api.example.com")
+        assert result is not None
+        assert result.resource == "https://api.example.com"
+
+        # 404 then Success
+        res404 = Mock()
+        res404.status_code = 404
+        mock_get.side_effect = [res404, mock_response]
+        result = discover_protected_resource_metadata(None, "https://api.example.com/path")
+        assert result is not None
+        assert result.resource == "https://api.example.com"
+
+        # Error handling
+        mock_get.side_effect = httpx.RequestError("Error")
+        result = discover_protected_resource_metadata(None, "https://api.example.com")
+        assert result is None
+
+    @patch("core.helper.ssrf_proxy.get")
+    def test_discover_oauth_authorization_server_metadata(self, mock_get):
+        # Success
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "authorization_endpoint": "https://auth.example.com/auth",
+            "token_endpoint": "https://auth.example.com/token",
+            "response_types_supported": ["code"],
+        }
+        mock_get.return_value = mock_response
+        result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
+        assert result is not None
+        assert result.authorization_endpoint == "https://auth.example.com/auth"
+
+        # 404
+        res404 = Mock()
+        res404.status_code = 404
+        mock_get.side_effect = [res404, mock_response]
+        result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant")
+        assert result is not None
+        assert result.authorization_endpoint == "https://auth.example.com/auth"
+
+        # ValidationError
+        mock_response.json.return_value = {"invalid": "data"}
+        mock_get.side_effect = None
+        mock_get.return_value = mock_response
+        result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
+        assert result is None
+
+    def test_get_effective_scope(self):
+        prm = ProtectedResourceMetadata(
+            resource="https://api.example.com",
+            authorization_servers=["https://auth"],
+            scopes_supported=["read", "write"],
+        )
+        asm = OAuthMetadata(
+            authorization_endpoint="https://auth.example.com/auth",
+            token_endpoint="https://auth.example.com/token",
+            response_types_supported=["code"],
+            scopes_supported=["openid", "profile"],
+        )
+
+        # 1. WWW-Auth priority
+        assert get_effective_scope("scope1", prm, asm, "client") == "scope1"
+        # 2. PRM priority
+        assert get_effective_scope(None, prm, asm, "client") == "read write"
+        # 3. ASM priority
+        assert get_effective_scope(None, None, asm, "client") == "openid profile"
+        # 4. Client configured
+        assert get_effective_scope(None, None, None, "client") == "client"
+
+    @patch("core.mcp.auth.auth_flow.redis_client")
+    def test_redis_state_management(self, mock_redis):
+        state_data = OAuthCallbackState(
+            provider_id="p1",
+            tenant_id="t1",
+            server_url="https://api",
+            metadata=None,
+            client_information=OAuthClientInformation(client_id="c1"),
+            code_verifier="cv",
+            redirect_uri="https://re",
+        )
+
+        # Create
+        state_key = _create_secure_redis_state(state_data)
+        assert state_key
+        mock_redis.setex.assert_called_once()
+
+        # Retrieve Success
+        mock_redis.get.return_value = state_data.model_dump_json()
+        retrieved = _retrieve_redis_state(state_key)
+        assert retrieved.provider_id == "p1"
+        mock_redis.delete.assert_called_once()
+
+        # Retrieve Failure - Not found
+        mock_redis.get.return_value = None
+        with pytest.raises(ValueError, match="expired or does not exist"):
+            _retrieve_redis_state("absent")
+
+        # Retrieve Failure - Invalid JSON
+        mock_redis.get.return_value = "invalid"
+        with pytest.raises(ValueError, match="Invalid state parameter"):
+            _retrieve_redis_state("invalid")
+
+    @patch("core.mcp.auth.auth_flow._retrieve_redis_state")
+    @patch("core.mcp.auth.auth_flow.exchange_authorization")
+    def test_handle_callback(self, mock_exchange, mock_retrieve):
+        state = Mock(spec=OAuthCallbackState)
+        state.server_url = "https://api"
+        state.metadata = None
+        state.client_information = Mock()
+        state.code_verifier = "cv"
+        state.redirect_uri = "https://re"
+        mock_retrieve.return_value = state
+
+        tokens = Mock(spec=OAuthTokens)
+        mock_exchange.return_value = tokens
+
+        s, t = handle_callback("key", "code")
+        assert s == state
+        assert t == tokens
+
+    @patch("core.helper.ssrf_proxy.get")
+    def test_check_support_resource_discovery(self, mock_get):
+        # Case 1: authorization_servers (plural)
+        res = Mock()
+        res.status_code = 200
+        res.json.return_value = {"authorization_servers": ["https://auth1"]}
+        mock_get.return_value = res
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is True
+        assert url == "https://auth1"
+
+        # Case 2: authorization_server_url (singular alias)
+        res.json.return_value = {"authorization_server_url": ["https://auth2"]}
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is True
+        assert url == "https://auth2"
+
+        # Case 3: Missing fields
+        res.json.return_value = {"nothing": []}
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is False
+
+        # Case 4: 404
+        res.status_code = 404
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is False
+
+        # Case 5: RequestError
+        mock_get.side_effect = httpx.RequestError("Error")
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is False
+
+    def test_discover_oauth_metadata(self):
+        with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
+            with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
+                mock_prm.return_value = ProtectedResourceMetadata(
+                    resource="https://api", authorization_servers=["https://auth"]
+                )
+                mock_asm.return_value = Mock(spec=OAuthMetadata)
+
+                asm, prm, hint = discover_oauth_metadata("https://api")
+                assert asm == mock_asm.return_value
+                assert prm == mock_prm.return_value
+                mock_asm.assert_called_with("https://auth", "https://api", None)
+
+    def test_start_authorization(self):
+        metadata = OAuthMetadata(
+            authorization_endpoint="https://auth/authorize",
+            token_endpoint="https://auth/token",
+            response_types_supported=["code"],
+        )
+        client_info = OAuthClientInformation(client_id="c1")
+
+        with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create:
+            mock_create.return_value = "state-key"
+
+            # Success with scope
+            url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read")
+            assert "scope=read" in url
+            assert "state=state-key" in url
+
+            # Success without metadata
+            url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1")
+            assert "https://api/authorize" in url
+
+            # Failure: incompatible auth server
+            metadata.response_types_supported = ["implicit"]
+            with pytest.raises(ValueError, match="Incompatible auth server"):
+                start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1")
+
+    def test_parse_token_response(self):
+        # Case 1: JSON
+        res = Mock()
+        res.headers = {"content-type": "application/json"}
+        res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
+        tokens = _parse_token_response(res)
+        assert tokens.access_token == "at"
+
+        # Case 2: Form-urlencoded
+        res.headers = {"content-type": "application/x-www-form-urlencoded"}
+        res.text = "access_token=at2&token_type=Bearer"
+        tokens = _parse_token_response(res)
+        assert tokens.access_token == "at2"
+
+        # Case 3: No content-type, but JSON
+        res.headers = {}
+        res.json.return_value = {"access_token": "at3", "token_type": "Bearer"}
+        tokens = _parse_token_response(res)
+        assert tokens.access_token == "at3"
+
+        # Case 4: No content-type, not JSON, but Form
+        res.json.side_effect = json.JSONDecodeError("msg", "doc", 0)
+        res.text = "access_token=at4&token_type=Bearer"
+        tokens = _parse_token_response(res)
+        assert tokens.access_token == "at4"
+
+        # Case 5: Validation Error fallback
+        res.json.side_effect = ValidationError.from_exception_data("error", [])
+        res.text = "access_token=at5&token_type=Bearer"
+        tokens = _parse_token_response(res)
+        assert tokens.access_token == "at5"
+
+    @patch("core.helper.ssrf_proxy.post")
+    def test_exchange_authorization(self, mock_post):
+        client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
+        metadata = OAuthMetadata(
+            authorization_endpoint="https://auth/authorize",
+            token_endpoint="https://auth/token",
+            response_types_supported=["code"],
+            grant_types_supported=["authorization_code"],
+        )
+
+        # Success
+        res = Mock()
+        res.is_success = True
+        res.headers = {"content-type": "application/json"}
+        res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
+        mock_post.return_value = res
+
+        tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
+        assert tokens.access_token == "at"
+
+        # Failure: Unsupported grant type
+        metadata.grant_types_supported = ["client_credentials"]
+        with pytest.raises(ValueError, match="Incompatible auth server"):
+            exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
+
+        # Failure: HTTP error
+        metadata.grant_types_supported = ["authorization_code"]
+        res.is_success = False
+        res.status_code = 400
+        with pytest.raises(ValueError, match="Token exchange failed"):
+            exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
+
+    @patch("core.helper.ssrf_proxy.post")
+    def test_refresh_authorization(self, mock_post):
+        # Case 1: with client_secret
+        client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
+
+        # Success
+        res = Mock()
+        res.is_success = True
+        res.headers = {"content-type": "application/json"}
+        res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"}
+        mock_post.return_value = res
+
+        tokens = refresh_authorization("https://api", None, client_info, "rt")
+        assert tokens.access_token == "at_new"
+        assert mock_post.call_args[1]["data"]["client_secret"] == "s1"
+
+        # Failure: MaxRetriesExceededError
+        mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries")
+        with pytest.raises(MCPRefreshTokenError):
+            refresh_authorization("https://api", None, client_info, "rt")
+
+        # Failure: HTTP error
+        mock_post.side_effect = None
+        res.is_success = False
+        res.text = "error_msg"
+        with pytest.raises(MCPRefreshTokenError, match="error_msg"):
+            refresh_authorization("https://api", None, client_info, "rt")
+
+        # Failure: Incompatible metadata
+        metadata = OAuthMetadata(
+            authorization_endpoint="https://auth/auth",
+            token_endpoint="https://auth/token",
+            response_types_supported=["code"],
+            grant_types_supported=["authorization_code"],
+        )
+        with pytest.raises(ValueError, match="Incompatible auth server"):
+            refresh_authorization("https://api", metadata, client_info, "rt")
+
+    @patch("core.helper.ssrf_proxy.post")
+    def test_client_credentials_flow(self, mock_post):
+        client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
+
+        # Success with secret
+        res = Mock()
+        res.is_success = True
+        res.headers = {"content-type": "application/json"}
+        res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"}
+        mock_post.return_value = res
+
+        tokens = client_credentials_flow("https://api", None, client_info, "read")
+        assert tokens.access_token == "at_cc"
+        args, kwargs = mock_post.call_args
+        assert "Authorization" in kwargs["headers"]
+
+        # Success without secret
+        client_info_no_secret = OAuthClientInformation(client_id="c2")
+        tokens = client_credentials_flow("https://api", None, client_info_no_secret)
+        args, kwargs = mock_post.call_args
+        assert kwargs["data"]["client_id"] == "c2"
+
+        # Failure: Incompatible metadata
+        metadata = OAuthMetadata(
+            authorization_endpoint="https://auth/auth",
+            token_endpoint="https://auth/token",
+            response_types_supported=["code"],
+            grant_types_supported=["authorization_code"],
+        )
+        with pytest.raises(ValueError, match="Incompatible auth server"):
+            client_credentials_flow("https://api", metadata, client_info)
+
+        # Failure: HTTP error
+        res.is_success = False
+        res.status_code = 401
+        res.text = "Unauthorized"
+        with pytest.raises(ValueError, match="Client credentials token request failed"):
+            client_credentials_flow("https://api", None, client_info)
+
+    @patch("core.helper.ssrf_proxy.post")
+    def test_register_client(self, mock_post):
+        # Case 1: Success with metadata
+        metadata = OAuthMetadata(
+            authorization_endpoint="https://auth/auth",
+            token_endpoint="https://auth/token",
+            registration_endpoint="https://auth/register",
+            response_types_supported=["code"],
+        )
+        client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"])
+
+        res = Mock()
+        res.is_success = True
+        res.json.return_value = {
+            "client_id": "c_new",
+            "client_secret": "s_new",
+            "client_name": "Dify",
+            "redirect_uris": ["https://re"],
+        }
+        mock_post.return_value = res
+
+        info = register_client("https://api", metadata, client_metadata)
+        assert info.client_id == "c_new"
+
+        # Case 2: Success without metadata
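+        # (register_client is expected to fall back to the default
+        # <server_url>/register endpoint, asserted below.)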
register_client("https://api", None, client_metadata) + assert mock_post.call_args[0][0] == "https://api/register" + + # Case 3: Metadata provided but no endpoint + metadata.registration_endpoint = None + with pytest.raises(ValueError, match="does not support dynamic client registration"): + register_client("https://api", metadata, client_metadata) + + # Failure: HTTP + res.is_success = False + res.raise_for_status = Mock() + res.status_code = 400 + # If is_success is false, it should call raise_for_status + register_client("https://api", None, client_metadata) + res.raise_for_status.assert_called_once() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_failures(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + + # Case 1: No server metadata + mock_discover.return_value = (None, None, None) + with pytest.raises(ValueError, match="Failed to discover OAuth metadata"): + auth(provider) + + # Case 2: No client info, exchange code provided + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + ) + mock_discover.return_value = (asm, None, None) + provider.retrieve_client_information.return_value = None + with pytest.raises(ValueError, match="Existing OAuth client information is required"): + auth(provider, authorization_code="code") + + # Case 3: CLIENT_CREDENTIALS but client must provide info + asm.grant_types_supported = ["client_credentials"] + with pytest.raises(ValueError, match="requires client_id and client_secret"): + auth(provider) + + # Case 4: Client registration fails + asm.grant_types_supported = ["authorization_code"] + with patch("core.mcp.auth.auth_flow.register_client") as mock_reg: + mock_reg.side_effect = httpx.RequestError("Reg failed") + with pytest.raises(ValueError, match="Could not register OAuth client"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_client_credentials(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1") + provider.decrypt_credentials.return_value = {"scope": "read"} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["client_credentials"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc: + mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer") + + result = auth(provider) + assert result.response == {"result": "success"} + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data["grant_type"] == "client_credentials" + + # Failure in CC flow + mock_cc.side_effect = ValueError("CC Failed") + with pytest.raises(ValueError, match="Client credentials flow failed"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_authorization_code(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = 
"t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + # Case 1: Exchange code + with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve: + state = Mock(spec=OAuthCallbackState) + state.code_verifier = "cv" + state.redirect_uri = "https://re" + mock_retrieve.return_value = state + + with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange: + mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer") + + # Success + result = auth(provider, authorization_code="code", state_param="sp") + assert result.response == {"result": "success"} + + # Missing state_param + with pytest.raises(ValueError, match="State parameter is required"): + auth(provider, authorization_code="code") + + # Missing verifier in state + state.code_verifier = None + with pytest.raises(ValueError, match="Missing code_verifier"): + auth(provider, authorization_code="code", state_param="sp") + + # Invalid state + mock_retrieve.side_effect = ValueError("Invalid") + with pytest.raises(ValueError, match="Invalid state parameter"): + auth(provider, authorization_code="code", state_param="sp") + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_refresh_failure(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt") + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh: + mock_refresh.side_effect = ValueError("Refresh Failed") + with pytest.raises(ValueError, match="Could not refresh OAuth tokens"): + auth(provider) diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 490a647025..e6eeb6cd59 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -322,3 +322,475 @@ def test_sse_client_concurrent_access(): assert len(received_messages) == 10 for i in range(10): assert f"message_{i}" in received_messages + + +class TestStatusClasses: + """Tests for _StatusReady and _StatusError data containers.""" + + def test_status_ready_stores_endpoint(self): + from core.mcp.client.sse_client import _StatusReady + + status = _StatusReady("http://example.com/messages/") + assert status.endpoint_url == "http://example.com/messages/" + + def test_status_error_stores_exception(self): + from core.mcp.client.sse_client import _StatusError + + exc = ValueError("bad endpoint") + status = _StatusError(exc) + assert status.exc is exc + + +class TestSSETransportInit: + """Tests for SSETransport default and explicit init values.""" + + def 
+    def test_defaults(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        t = SSETransport("http://example.com/sse")
+        assert t.url == "http://example.com/sse"
+        assert t.headers == {}
+        assert t.timeout == 5.0
+        assert t.sse_read_timeout == 60.0
+        assert t.endpoint_url is None
+        assert t.event_source is None
+
+    def test_explicit_headers_not_mutated(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        hdrs = {"X-Foo": "bar"}
+        t = SSETransport("http://example.com/sse", headers=hdrs)
+        assert t.headers is hdrs
+
+
+class TestHandleEndpointEvent:
+    """Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch."""
+
+    def test_invalid_origin_puts_status_error(self):
+        from core.mcp.client.sse_client import SSETransport, _StatusError
+
+        transport = SSETransport("http://example.com/sse")
+        status_queue: queue.Queue = queue.Queue()
+
+        # Provide a full URL with a different origin so urljoin keeps it as-is
+        transport._handle_endpoint_event("http://evil.com/messages/", status_queue)
+
+        result = status_queue.get_nowait()
+        assert isinstance(result, _StatusError)
+        assert "does not match" in str(result.exc)
+
+    def test_valid_origin_puts_status_ready(self):
+        from core.mcp.client.sse_client import SSETransport, _StatusReady
+
+        transport = SSETransport("http://example.com/sse")
+        status_queue: queue.Queue = queue.Queue()
+
+        transport._handle_endpoint_event("/messages/?session_id=abc", status_queue)
+
+        result = status_queue.get_nowait()
+        assert isinstance(result, _StatusReady)
+        assert "example.com" in result.endpoint_url
+
+
+class TestHandleSSEEvent:
+    """Tests for SSETransport._handle_sse_event covering all match branches."""
+
+    def _make_sse(self, event_type: str, data: str):
+        sse = Mock()
+        sse.event = event_type
+        sse.data = data
+        return sse
+
+    def test_message_event_dispatched(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        read_queue: queue.Queue = queue.Queue()
+        status_queue: queue.Queue = queue.Queue()
+
+        valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+        transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue)
+
+        item = read_queue.get_nowait()
+        assert hasattr(item, "message")
+
+    def test_unknown_event_logs_warning_and_does_nothing(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        read_queue: queue.Queue = queue.Queue()
+        status_queue: queue.Queue = queue.Queue()
+
+        transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue)
+
+        assert read_queue.empty()
+        assert status_queue.empty()
+
+
+class TestSSEReader:
+    """Tests for SSETransport.sse_reader exception branches."""
+
+    def test_read_error_closes_cleanly(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        read_queue: queue.Queue = queue.Queue()
+        status_queue: queue.Queue = queue.Queue()
+
+        event_source = Mock()
+        event_source.iter_sse.side_effect = httpx.ReadError("connection reset")
+
+        transport.sse_reader(event_source, read_queue, status_queue)
+
+        # Finally block always puts None as sentinel
+        sentinel = read_queue.get_nowait()
+        assert sentinel is None
+
+    def test_generic_exception_puts_exc_then_none(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        read_queue: queue.Queue = queue.Queue()
+        status_queue: queue.Queue = queue.Queue()
+
+        boom = RuntimeError("unexpected!")
+        event_source = Mock()
+        event_source.iter_sse.side_effect = boom
+
+        transport.sse_reader(event_source, read_queue, status_queue)
+
+        exc_item = read_queue.get_nowait()
+        assert exc_item is boom
+
+        sentinel = read_queue.get_nowait()
+        assert sentinel is None
+
+
+class TestSendMessage:
+    """Tests for SSETransport._send_message."""
+
+    def _make_session_message(self):
+        msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+        msg = types.JSONRPCMessage.model_validate_json(msg_json)
+        return types.SessionMessage(msg)
+
+    def test_sends_post_and_raises_for_status(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_client = Mock()
+        mock_client.post.return_value = mock_response
+
+        session_msg = self._make_session_message()
+        transport._send_message(mock_client, "http://example.com/messages/", session_msg)
+
+        mock_client.post.assert_called_once()
+        mock_response.raise_for_status.assert_called_once()
+
+
+class TestPostWriter:
+    """Tests for SSETransport.post_writer exception branches."""
+
+    def _make_session_message(self):
+        msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+        msg = types.JSONRPCMessage.model_validate_json(msg_json)
+        return types.SessionMessage(msg)
+
+    def test_none_message_exits_loop(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        write_queue: queue.Queue = queue.Queue()
+        write_queue.put(None)  # Signal shutdown immediately
+
+        mock_client = Mock()
+        transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
+
+        # Should put final None sentinel
+        sentinel = write_queue.get_nowait()
+        assert sentinel is None
+
+    def test_exception_in_message_put_back_to_queue(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        write_queue: queue.Queue = queue.Queue()
+
+        exc = ValueError("some error")
+        write_queue.put(exc)  # Exception goes in first
+        write_queue.put(None)  # Then shutdown signal
+
+        mock_client = Mock()
+        transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
+
+        # The exception should be re-queued, then None from loop exit, then None from finally
+        item1 = write_queue.get_nowait()
+        assert isinstance(item1, Exception)
+
+    def test_read_error_shuts_down_cleanly(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        write_queue: queue.Queue = queue.Queue()
+
+        session_msg = self._make_session_message()
+        write_queue.put(session_msg)
+
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_client = Mock()
+        mock_client.post.side_effect = httpx.ReadError("connection dropped")
+
+        # post_writer calls _send_message which calls client.post → ReadError propagates
+        # The ReadError is raised inside _send_message → propagates out of the while loop
+        transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
+
+        # finally always puts None
+        sentinel = write_queue.get_nowait()
+        assert sentinel is None
+
+    def test_generic_exception_puts_exc_in_queue(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        write_queue: queue.Queue = queue.Queue()
+
+        session_msg = self._make_session_message()
+        write_queue.put(session_msg)
+
+        mock_client = Mock()
+        boom = RuntimeError("boom")
+        mock_client.post.side_effect = boom
+
+        transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
+
+        exc_item = write_queue.get_nowait()
+        assert isinstance(exc_item, Exception)
+
+        sentinel = write_queue.get_nowait()
+        assert sentinel is None
+
+    def test_queue_empty_timeout_continues_loop(self):
+        """Cover the 'except queue.Empty: continue' branch (line 188) in post_writer."""
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        write_queue: queue.Queue = queue.Queue()
+
+        mock_client = Mock()
+
+        # Patch queue.Queue.get so it raises Empty first, then returns None (shutdown)
+        call_count = {"n": 0}
+        original_get = write_queue.get
+
+        def patched_get(*args, **kwargs):
+            call_count["n"] += 1
+            if call_count["n"] == 1:
+                raise queue.Empty
+
+        write_queue.get = patched_get  # type: ignore[method-assign]
+
+        transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
+
+        # finally always puts None sentinel
+        sentinel = write_queue.get_nowait()
+        assert sentinel is None
+        assert call_count["n"] >= 2  # Empty on first, None on second (and possibly more retries)
+
+
+class TestWaitForEndpoint:
+    """Tests for SSETransport._wait_for_endpoint edge cases."""
+
+    def test_raises_on_empty_queue(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        status_queue: queue.Queue = queue.Queue()  # empty
+
+        with pytest.raises(ValueError, match="failed to get endpoint URL"):
+            transport._wait_for_endpoint(status_queue)
+
+    def test_raises_status_error_exception(self):
+        from core.mcp.client.sse_client import SSETransport, _StatusError
+
+        transport = SSETransport("http://example.com/sse")
+        status_queue: queue.Queue = queue.Queue()
+
+        exc = ValueError("malicious endpoint")
+        status_queue.put(_StatusError(exc))
+
+        with pytest.raises(ValueError, match="malicious endpoint"):
+            transport._wait_for_endpoint(status_queue)
+
+    def test_raises_on_unknown_status_type(self):
+        from core.mcp.client.sse_client import SSETransport
+
+        transport = SSETransport("http://example.com/sse")
+        status_queue: queue.Queue = queue.Queue()
+
+        # Put an object that is neither _StatusReady nor _StatusError
+        status_queue.put("unexpected_value")
+
+        with pytest.raises(ValueError, match="failed to get endpoint URL"):
+            transport._wait_for_endpoint(status_queue)
+
+
+class TestSSEClientRuntimeError:
+    """Test sse_client context manager handles RuntimeError on close()."""
+
+    def test_runtime_error_on_close_is_suppressed(self):
+        """Ensure RuntimeError raised by event_source.response.close() is caught."""
+        test_url = "http://test.example/sse"
+
+        class MockSSEEvent:
+            def __init__(self, event_type: str, data: str):
+                self.event = event_type
+                self.data = data
+
+        endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
+
+        with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc:
+                mock_client = Mock()
+                mock_cf.return_value.__enter__.return_value = mock_client
+
+                mock_es = Mock()
+                mock_es.response.raise_for_status.return_value = None
+                mock_es.iter_sse.return_value = [endpoint_event]
+                # Make close() raise RuntimeError to exercise lines 307-308
+                mock_es.response.close.side_effect = RuntimeError("already closed")
+                mock_sc.return_value.__enter__.return_value = mock_es
+
+                # Should NOT raise even though close() raises RuntimeError
+                with contextlib.suppress(Exception):
+                    with sse_client(test_url) as (rq, wq):
+                        pass
+
+
+class TestStandaloneSendMessage:
+    """Tests for the module-level send_message() function."""
+
+    def _make_session_message(self):
+        msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+        msg = types.JSONRPCMessage.model_validate_json(msg_json)
+        return types.SessionMessage(msg)
+
+    def test_send_message_success(self):
+        from core.mcp.client.sse_client import send_message
+
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_http_client = Mock()
+        mock_http_client.post.return_value = mock_response
+
+        session_msg = self._make_session_message()
+        send_message(mock_http_client, "http://example.com/messages/", session_msg)
+
+        mock_http_client.post.assert_called_once()
+        mock_response.raise_for_status.assert_called_once()
+
+    def test_send_message_raises_on_http_error(self):
+        from core.mcp.client.sse_client import send_message
+
+        mock_http_client = Mock()
+        mock_http_client.post.side_effect = httpx.ConnectError("refused")
+
+        session_msg = self._make_session_message()
+
+        with pytest.raises(httpx.ConnectError):
+            send_message(mock_http_client, "http://example.com/messages/", session_msg)
+
+    def test_send_message_raises_for_status_failure(self):
+        from core.mcp.client.sse_client import send_message
+
+        mock_response = Mock()
+        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "Not Found", request=Mock(), response=Mock(status_code=404)
+        )
+        mock_http_client = Mock()
+        mock_http_client.post.return_value = mock_response
+
+        session_msg = self._make_session_message()
+
+        with pytest.raises(httpx.HTTPStatusError):
+            send_message(mock_http_client, "http://example.com/messages/", session_msg)
+
+
+class TestReadMessages:
+    """Tests for the module-level read_messages() generator."""
+
+    def _make_mock_sse_event(self, event_type: str, data: str):
+        ev = Mock()
+        ev.event = event_type
+        ev.data = data
+        return ev
+
+    def test_valid_message_event_yields_session_message(self):
+        from core.mcp.client.sse_client import read_messages
+
+        valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+        mock_sse_event = self._make_mock_sse_event("message", valid_json)
+
+        mock_client = Mock()
+        mock_client.events.return_value = [mock_sse_event]
+
+        results = list(read_messages(mock_client))
+        assert len(results) == 1
+        assert hasattr(results[0], "message")
+
+    def test_invalid_json_yields_exception(self):
+        from core.mcp.client.sse_client import read_messages
+
+        mock_sse_event = self._make_mock_sse_event("message", "{not valid json}")
+
+        mock_client = Mock()
+        mock_client.events.return_value = [mock_sse_event]
+
+        results = list(read_messages(mock_client))
+        assert len(results) == 1
+        assert isinstance(results[0], Exception)
+
+    def test_non_message_event_is_skipped(self):
+        from core.mcp.client.sse_client import read_messages
+
+        mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/")
+
+        mock_client = Mock()
+        mock_client.events.return_value = [mock_sse_event]
+
+        results = list(read_messages(mock_client))
+        # Non-message events produce no output
+        assert results == []
+
+    def test_outer_exception_yields_exc(self):
+        from core.mcp.client.sse_client import read_messages
+
+        boom = RuntimeError("stream broken")
+        mock_client = Mock()
+        mock_client.events.side_effect = boom
+
+        results = list(read_messages(mock_client))
+        assert len(results) == 1
+        assert results[0] is boom
+
+    def test_multiple_events_mixed(self):
+        from core.mcp.client.sse_client import read_messages
+
+        valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}'
+        events = [
+            self._make_mock_sse_event("endpoint", "/messages/"),
+            self._make_mock_sse_event("message", valid_json),
+            self._make_mock_sse_event("message", "{bad json}"),
+        ]
+
+        mock_client = Mock()
+        mock_client.events.return_value = events
+
+        results = list(read_messages(mock_client))
+        # endpoint is skipped; 1 valid SessionMessage + 1 Exception
+        assert len(results) == 2
+        assert hasattr(results[0], "message")
+        assert isinstance(results[1], Exception)
diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
index 9a30a35a49..81f8da9a62 100644
--- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
+++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
@@ -4,14 +4,39 @@
 Tests for the StreamableHTTP client transport.
 
 Contains tests for only the client side of the StreamableHTTP transport.
 """
 
+import json
 import queue
 import threading
 import time
+from contextlib import contextmanager
+from datetime import timedelta
 from typing import Any
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
+
+import httpx
+import pytest
+from httpx_sse import ServerSentEvent
 
 from core.mcp import types
-from core.mcp.client.streamable_client import streamablehttp_client
+from core.mcp.client.streamable_client import (
+    LAST_EVENT_ID,
+    MCP_SESSION_ID,
+    RequestContext,
+    ResumptionError,
+    StreamableHTTPError,
+    StreamableHTTPTransport,
+    streamablehttp_client,
+)
+from core.mcp.types import (
+    ClientMessageMetadata,
+    ErrorData,
+    JSONRPCError,
+    JSONRPCMessage,
+    JSONRPCNotification,
+    JSONRPCRequest,
+    JSONRPCResponse,
+    SessionMessage,
+)
 
 # Test constants
 SERVER_NAME = "test_streamable_http_server"
@@ -448,3 +473,1169 @@ def test_streamablehttp_client_resumption_token_handling():
             assert write_queue is not None
     except Exception:
         pass  # Expected due to mocking
+
+
+# ── helpers ───────────────────────────────────────────────────────────────────
+
+
+def _make_request_msg(method: str = "ping", req_id: int = 1) -> JSONRPCMessage:
+    return JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=req_id, method=method))
+
+
+def _make_response_msg(req_id: int = 1, result: dict | None = None) -> JSONRPCMessage:
+    return JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=req_id, result=result or {}))
+
+
+def _make_error_msg(req_id: int = 1, code: int = -32600) -> JSONRPCMessage:
+    return JSONRPCMessage(root=JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=code, message="err")))
+
+
+def _make_notification_msg(method: str = "notifications/initialized") -> JSONRPCMessage:
+    return JSONRPCMessage(root=JSONRPCNotification(jsonrpc="2.0", method=method))
+
+
+def _make_sse_mock(event: str = "message", data: str = "", sse_id: str = "") -> ServerSentEvent:
+    # Use real ServerSentEvent since StreamableHTTPTransport requires its structure
+    return ServerSentEvent(event=event, data=data, id=sse_id, retry=None)
+
+
+def _new_transport(url: str = "http://example.com/mcp", **kwargs) -> StreamableHTTPTransport:
+    return StreamableHTTPTransport(url, **kwargs)
+
+
+# ── StreamableHTTPTransport.__init__ ─────────────────────────────────────────
+
+
+class TestStreamableHTTPTransportInit:
+    def test_defaults(self):
+        t = _new_transport()
+        assert t.url == "http://example.com/mcp"
+        assert t.headers == {}
+        assert t.timeout == 30
+        assert t.sse_read_timeout == 300
+        assert t.session_id is None
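+        # Shutdown coordination and in-flight response tracking start out inert: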
+ assert t.stop_event is not None + assert t._active_responses == [] + + def test_timedelta_timeout_and_sse_read_timeout(self): + t = _new_transport(timeout=timedelta(seconds=10), sse_read_timeout=timedelta(seconds=120)) + assert t.timeout == 10.0 + assert t.sse_read_timeout == 120.0 + + def test_custom_headers_merged_into_request_headers(self): + t = _new_transport(headers={"Authorization": "Bearer tok"}) + assert t.request_headers["Authorization"] == "Bearer tok" + assert "Accept" in t.request_headers + assert "content-type" in t.request_headers + + +# ── _update_headers_with_session ───────────────────────────────────────────── + + +class TestUpdateHeadersWithSession: + def test_no_session_id_returns_copy_without_session_header(self): + t = _new_transport() + t.session_id = None + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result == {"X-Foo": "bar"} + assert MCP_SESSION_ID not in result + + def test_with_session_id_adds_header(self): + t = _new_transport() + t.session_id = "sess-abc" + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result[MCP_SESSION_ID] == "sess-abc" + assert result["X-Foo"] == "bar" + + +# ── _register_response / _unregister_response / close_active_responses ──────── + + +class TestResponseRegistry: + def test_register_and_unregister(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._register_response(resp) + assert resp in t._active_responses + t._unregister_response(resp) + assert resp not in t._active_responses + + def test_unregister_not_registered_does_not_raise(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._unregister_response(resp) # Should swallow ValueError silently + + def test_close_active_responses_calls_close(self): + t = _new_transport() + resp1 = MagicMock(spec=httpx.Response) + resp2 = MagicMock(spec=httpx.Response) + t._register_response(resp1) + t._register_response(resp2) + t.close_active_responses() + resp1.close.assert_called_once() + resp2.close.assert_called_once() + assert t._active_responses == [] + + def test_close_active_responses_swallows_runtime_error(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + resp.close.side_effect = RuntimeError("already closed") + t._register_response(resp) + t.close_active_responses() # Should not raise + + +# ── _is_initialization_request / _is_initialized_notification ──────────────── + + +class TestMessageClassifiers: + def test_is_initialization_request_true(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("initialize")) is True + + def test_is_initialization_request_false_other_method(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("tools/list")) is False + + def test_is_initialization_request_false_not_request(self): + t = _new_transport() + assert t._is_initialization_request(_make_response_msg()) is False + + def test_is_initialized_notification_true(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/initialized")) is True + + def test_is_initialized_notification_false_other_method(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/cancelled")) is False + + def test_is_initialized_notification_false_not_notification(self): + t = _new_transport() + assert t._is_initialized_notification(_make_request_msg("notifications/initialized")) is False + + +# ── _maybe_extract_session_id_from_response 
─────────────────────────────────── + + +class TestMaybeExtractSessionIdNew: + def test_extracts_session_id_when_present(self): + t = _new_transport() + resp = MagicMock() + resp.headers = {MCP_SESSION_ID: "new-session-99"} + t._maybe_extract_session_id_from_response(resp) + assert t.session_id == "new-session-99" + + def test_no_session_id_header_leaves_none(self): + t = _new_transport() + resp = MagicMock() + resp.headers = MagicMock() + resp.headers.get = MagicMock(return_value=None) + t._maybe_extract_session_id_from_response(resp) + assert t.session_id is None + + +# ── _handle_sse_event ───────────────────────────────────────────────────────── + + +class TestHandleSseEventNew: + def test_message_event_response_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})) + assert t._handle_sse_event(sse, q) is True + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_error_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "bad"}}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is True + + def test_message_event_notification_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "method": "notifications/something"}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_empty_data_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", " ") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_message_event_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", "{bad json}") + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), Exception) + + def test_message_event_replaces_original_request_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + t._handle_sse_event(sse, q, original_request_id=999) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert item.message.root.id == 999 + + def test_message_event_calls_resumption_callback_when_sse_id_present(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="token-abc") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_called_once_with("token-abc") + + def test_message_event_no_callback_when_no_sse_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_not_called() + + def test_ping_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("ping", "") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_unknown_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = 
_make_sse_mock("custom_event", "{}") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + +# ── handle_get_stream ───────────────────────────────────────────────────────── + + +class TestHandleGetStreamNew: + def test_skips_when_no_session_id(self): + t = _new_transport() + t.session_id = None + q: queue.Queue = queue.Queue() + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + t.handle_get_stream(MagicMock(), q) + mock_connect.assert_not_called() + + def test_handles_messages_via_sse(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert isinstance(q.get_nowait(), SessionMessage) + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert q.empty() + + def test_exception_when_not_stopped_is_logged(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise + + def test_exception_when_stopped_is_suppressed(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise or log + + +# ── _handle_resumption_request ──────────────────────────────────────────────── + + +class TestHandleResumptionRequestNew: + def _make_ctx(self, transport, q, resumption_token="token-123", message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", req_id=42) + session_msg = SessionMessage(message) + metadata = None + if resumption_token: + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = resumption_token + metadata.on_resumption_token_update = MagicMock() + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=transport.session_id, + session_message=session_msg, + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def test_raises_resumption_error_without_token(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + metadata = 
MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = None + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_raises_resumption_error_without_metadata(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_sets_last_event_id_header(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, resumption_token="resume-999") + + captured_headers: dict = {} + data = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + def fake_connect(url, headers, **kwargs): + captured_headers.update(headers) + + @contextmanager + def _ctx(): + yield mock_event_source + + return _ctx() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect", side_effect=fake_connect): + t._handle_resumption_request(ctx) + + assert captured_headers.get(LAST_EVENT_ID) == "resume-999" + + def test_stops_when_response_complete(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, message=_make_request_msg("tools/list", 42)) + + data1 = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + data2 = json.dumps({"jsonrpc": "2.0", "id": 43, "result": {}}) + sse1 = _make_sse_mock("message", data1) + sse2 = _make_sse_mock("message", data2) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [sse1, sse2] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + # Only the first event was processed (loop breaks on completion) + assert q.qsize() == 1 + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + assert q.empty() + + +# ── _handle_post_request ────────────────────────────────────────────────────── + + +class TestHandlePostRequestNew: + def _make_ctx(self, transport, q, message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", 
+
+
+# ── _handle_post_request ──────────────────────────────────────────────────────
+
+
+class TestHandlePostRequestNew:
+    def _make_ctx(self, transport, q, message=None) -> RequestContext:
+        if message is None:
+            message = _make_request_msg("tools/list", 1)
+        return RequestContext(
+            client=MagicMock(),
+            headers=transport.request_headers,
+            session_id=transport.session_id,
+            session_message=SessionMessage(message),
+            metadata=None,
+            server_to_client_queue=q,
+            sse_read_timeout=60,
+        )
+
+    def _stream_ctx(self, mock_response):
+        @contextmanager
+        def _stream(*args, **kwargs):
+            yield mock_response
+
+        return _stream
+
+    def test_202_returns_immediately_no_queue(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+        mock_resp = MagicMock()
+        mock_resp.status_code = 202
+        ctx.client.stream = self._stream_ctx(mock_resp)
+        t._handle_post_request(ctx)
+        assert q.empty()
+
+    def test_204_returns_immediately_no_queue(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+        mock_resp = MagicMock()
+        mock_resp.status_code = 204
+        ctx.client.stream = self._stream_ctx(mock_resp)
+        t._handle_post_request(ctx)
+        assert q.empty()
+
+    def test_404_sends_session_terminated_error_for_request(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        msg = _make_request_msg("tools/list", 77)
+        ctx = self._make_ctx(t, q, message=msg)
+        mock_resp = MagicMock()
+        mock_resp.status_code = 404
+        ctx.client.stream = self._stream_ctx(mock_resp)
+        t._handle_post_request(ctx)
+        item = q.get_nowait()
+        assert isinstance(item, SessionMessage)
+        assert isinstance(item.message.root, JSONRPCError)
+        assert item.message.root.id == 77
+
+    def test_404_for_notification_no_error_sent(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        msg = _make_notification_msg("some/notification")
+        ctx = self._make_ctx(t, q, message=msg)
+        mock_resp = MagicMock()
+        mock_resp.status_code = 404
+        ctx.client.stream = self._stream_ctx(mock_resp)
+        t._handle_post_request(ctx)
+        assert q.empty()
+
+    def test_json_response_puts_session_message(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+
+        response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}).encode()
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.headers = {"content-type": "application/json"}
+        mock_resp.raise_for_status.return_value = None
+        mock_resp.read.return_value = response_data
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        t._handle_post_request(ctx)
+        assert isinstance(q.get_nowait(), SessionMessage)
+
+    def test_json_response_invalid_json_puts_exception(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.headers = {"content-type": "application/json"}
+        mock_resp.raise_for_status.return_value = None
+        mock_resp.read.return_value = b"{bad json!"
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        t._handle_post_request(ctx)
+        assert isinstance(q.get_nowait(), Exception)
+
+    def test_unexpected_content_type_puts_value_error(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.headers = {"content-type": "text/plain"}
+        mock_resp.raise_for_status.return_value = None
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        t._handle_post_request(ctx)
+        item = q.get_nowait()
+        assert isinstance(item, ValueError)
+        assert "Unexpected content type" in str(item)
+
+    def test_initialization_request_extracts_session_id(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        msg = _make_request_msg("initialize", 1)
+        ctx = self._make_ctx(t, q, message=msg)
+
+        response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode()
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
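+        # httpx response headers are emulated with a MagicMock so that both item
+        # access and .get() resolve against headers_dict below.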
+        mock_resp.headers = MagicMock()
+        headers_dict = {"content-type": "application/json", MCP_SESSION_ID: "new-sid"}
+        mock_resp.headers.__getitem__ = lambda self, k: headers_dict[k]
+        mock_resp.headers.get = lambda k, default=None: headers_dict.get(k, default)
+        mock_resp.raise_for_status.return_value = None
+        mock_resp.read.return_value = response_data
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        t._handle_post_request(ctx)
+        assert t.session_id == "new-sid"
+
+    def test_notification_skips_response_processing(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        msg = _make_notification_msg("notifications/something")
+        ctx = self._make_ctx(t, q, message=msg)
+
+        response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode()
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.headers = {"content-type": "application/json"}
+        mock_resp.raise_for_status.return_value = None
+        mock_resp.read.return_value = response_data
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        t._handle_post_request(ctx)
+        assert q.empty()
+
+    def test_sse_response_handles_stream(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._make_ctx(t, q)
+
+        data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})
+        mock_sse_event = _make_sse_mock("message", data)
+
+        mock_resp = MagicMock()
+        mock_resp.status_code = 200
+        mock_resp.headers = {"content-type": "text/event-stream"}
+        mock_resp.raise_for_status.return_value = None
+        ctx.client.stream = self._stream_ctx(mock_resp)
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            mock_es_instance = MagicMock()
+            mock_es_instance.iter_sse.return_value = [mock_sse_event]
+            MockEventSource.return_value = mock_es_instance
+            t._handle_post_request(ctx)
+
+        assert isinstance(q.get_nowait(), SessionMessage)
+
+
+# ── _handle_json_response ─────────────────────────────────────────────────────
+
+
+class TestHandleJsonResponseNew:
+    def test_valid_json_puts_session_message(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode()
+        mock_response = MagicMock()
+        mock_response.read.return_value = data
+        t._handle_json_response(mock_response, q)
+        assert isinstance(q.get_nowait(), SessionMessage)
+
+    def test_invalid_json_puts_exception(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        mock_response = MagicMock()
+        mock_response.read.return_value = b"{ invalid }"
+        t._handle_json_response(mock_response, q)
+        assert isinstance(q.get_nowait(), Exception)
+
+
+# ── _handle_sse_response ──────────────────────────────────────────────────────
+
+
+class TestHandleSseResponseNew:
+    def _ctx(self, transport, q) -> RequestContext:
+        return RequestContext(
+            client=MagicMock(),
+            headers=transport.request_headers,
+            session_id=None,
+            session_message=SessionMessage(_make_request_msg()),
+            metadata=None,
+            server_to_client_queue=q,
+            sse_read_timeout=60,
+        )
+
+    def test_processes_sse_events(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._ctx(t, q)
+
+        data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})
+        mock_sse_event = _make_sse_mock("message", data)
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            mock_es_instance = MagicMock()
+            mock_es_instance.iter_sse.return_value = [mock_sse_event]
+            MockEventSource.return_value = mock_es_instance
+            t._handle_sse_response(mock_response, ctx)
+
+        assert isinstance(q.get_nowait(), SessionMessage)
+
+    def test_stops_when_stop_event_set(self):
+        t = _new_transport()
+        t.stop_event.set()
+        q: queue.Queue = queue.Queue()
+        ctx = self._ctx(t, q)
+
+        data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})
+        mock_sse_event = _make_sse_mock("message", data)
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            mock_es_instance = MagicMock()
+            mock_es_instance.iter_sse.return_value = [mock_sse_event]
+            MockEventSource.return_value = mock_es_instance
+            t._handle_sse_response(mock_response, ctx)
+
+        assert q.empty()
+
+    def test_stops_when_complete(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._ctx(t, q)
+
+        data1 = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})
+        data2 = json.dumps({"jsonrpc": "2.0", "id": 2, "result": {}})
+        sse1 = _make_sse_mock("message", data1)
+        sse2 = _make_sse_mock("message", data2)
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            mock_es_instance = MagicMock()
+            mock_es_instance.iter_sse.return_value = [sse1, sse2]
+            MockEventSource.return_value = mock_es_instance
+            t._handle_sse_response(mock_response, ctx)
+
+        assert q.qsize() == 1  # Only the first completion item
+
+    def test_exception_outside_stop_puts_to_queue(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        ctx = self._ctx(t, q)
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            MockEventSource.side_effect = RuntimeError("EventSource error")
+            t._handle_sse_response(mock_response, ctx)
+
+        assert isinstance(q.get_nowait(), Exception)
+
+    def test_exception_suppressed_when_stopped(self):
+        t = _new_transport()
+        t.stop_event.set()
+        q: queue.Queue = queue.Queue()
+        ctx = self._ctx(t, q)
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            MockEventSource.side_effect = RuntimeError("EventSource error")
+            t._handle_sse_response(mock_response, ctx)
+
+        assert q.empty()
+
+    def test_with_metadata_resumption_callback(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        metadata = MagicMock(spec=ClientMessageMetadata)
+        callback = MagicMock()
+        metadata.on_resumption_token_update = callback
+
+        ctx = RequestContext(
+            client=MagicMock(),
+            headers=t.request_headers,
+            session_id=None,
+            session_message=SessionMessage(_make_request_msg()),
+            metadata=metadata,
+            server_to_client_queue=q,
+            sse_read_timeout=60,
+        )
+
+        data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})
+        sse = _make_sse_mock("message", data, sse_id="resume-token")
+        mock_response = MagicMock()
+
+        with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource:
+            mock_es_instance = MagicMock()
+            mock_es_instance.iter_sse.return_value = [sse]
+            MockEventSource.return_value = mock_es_instance
+            t._handle_sse_response(mock_response, ctx)
+
+        callback.assert_called_once_with("resume-token")
+
+
+# ── _handle_unexpected_content_type ──────────────────────────────────────────
+
+
+class TestHandleUnexpectedContentTypeNew:
+    def test_puts_value_error_with_message(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        t._handle_unexpected_content_type("text/html", q)
+        item = q.get_nowait()
+        assert isinstance(item, ValueError)
+        assert "text/html" in str(item)
+
+
+# ── _send_session_terminated_error ────────────────────────────────────────────
+
+
+class TestSendSessionTerminatedErrorNew:
+    def test_puts_jsonrpc_error(self):
+        t = _new_transport()
+        q: queue.Queue = queue.Queue()
+        t._send_session_terminated_error(q, 42)
+        item = q.get_nowait()
+        assert isinstance(item, SessionMessage)
+        assert isinstance(item.message.root, JSONRPCError)
+        assert item.message.root.id == 42
+        assert item.message.root.error.code == 32600
+        assert "terminated" in item.message.root.error.message.lower()
+
+
+# ── post_writer ───────────────────────────────────────────────────────────────
+
+
+class TestPostWriterNew:
+    def test_none_message_exits_loop(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+        c2s.put(None)
+        t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+    def test_stop_event_exits_loop(self):
+        t = _new_transport()
+        t.stop_event.set()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+        t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+    def test_initialized_notification_calls_start_get_stream(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+        start_get_stream = MagicMock()
+
+        notif_msg = _make_notification_msg("notifications/initialized")
+        c2s.put(SessionMessage(notif_msg))
+        c2s.put(None)
+
+        with patch.object(t, "_handle_post_request"):
+            t.post_writer(MagicMock(), c2s, s2c, start_get_stream)
+
+        start_get_stream.assert_called_once()
+
+    def test_resumption_message_calls_handle_resumption_request(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+        start_get_stream = MagicMock()
+
+        msg = SessionMessage(_make_request_msg("tools/list", 10))
+        metadata = MagicMock(spec=ClientMessageMetadata)
+        metadata.resumption_token = "resume-abc"
+        msg.metadata = metadata
+        c2s.put(msg)
+        c2s.put(None)
+
+        with patch.object(t, "_handle_resumption_request") as mock_resumption:
+            t.post_writer(MagicMock(), c2s, s2c, start_get_stream)
+
+        mock_resumption.assert_called_once()
+
+    def test_regular_message_calls_handle_post_request(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+
+        msg = SessionMessage(_make_request_msg("tools/list", 5))
+        c2s.put(msg)
+        c2s.put(None)
+
+        with patch.object(t, "_handle_post_request") as mock_post:
+            t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+        mock_post.assert_called_once()
+
+    def test_exception_in_handler_put_to_s2c_when_not_stopped(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+
+        msg = SessionMessage(_make_request_msg("tools/list", 5))
+        c2s.put(msg)
+        c2s.put(None)
+
+        boom = RuntimeError("oops")
+        with patch.object(t, "_handle_post_request", side_effect=boom):
+            t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+        item = s2c.get_nowait()
+        assert item is boom
+
+    def test_exception_suppressed_when_stopped(self):
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+
+        msg = SessionMessage(_make_request_msg("tools/list", 5))
+        c2s.put(msg)
+        c2s.put(None)
+        t.stop_event.set()
+
+        boom = RuntimeError("oops")
+        with patch.object(t, "_handle_post_request", side_effect=boom):
+            t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+        assert s2c.empty()
+
+    def test_queue_empty_timeout_continues_loop(self):
+        """Cover the 'except queue.Empty: continue' branch in post_writer."""
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+        call_count = {"n": 0}
+
+        original_get = c2s.get
+
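+        # After the first queue.Empty, patched_get falls through and implicitly
+        # returns None, which post_writer treats as the shutdown sentinel.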
+        def patched_get(*args, **kwargs):
+            call_count["n"] += 1
+            if call_count["n"] == 1:
+                raise queue.Empty
+
+        c2s.get = patched_get  # type: ignore[method-assign]
+        t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+        assert call_count["n"] >= 2
+
+    def test_non_client_metadata_treated_as_none(self):
+        """session_message.metadata that's not ClientMessageMetadata → metadata is None."""
+        t = _new_transport()
+        c2s: queue.Queue = queue.Queue()
+        s2c: queue.Queue = queue.Queue()
+
+        msg = SessionMessage(_make_request_msg("tools/list", 5))
+        msg.metadata = "not-a-client-metadata"
+        c2s.put(msg)
+        c2s.put(None)
+
+        with patch.object(t, "_handle_post_request") as mock_post:
+            t.post_writer(MagicMock(), c2s, s2c, MagicMock())
+
+        ctx = mock_post.call_args[0][0]
+        assert ctx.metadata is None
+
+
+# ── terminate_session ─────────────────────────────────────────────────────────
+
+
+class TestTerminateSessionNew:
+    def test_no_session_id_skips(self):
+        t = _new_transport()
+        t.session_id = None
+        mock_client = MagicMock()
+        t.terminate_session(mock_client)
+        mock_client.delete.assert_not_called()
+
+    def test_200_response_is_success(self):
+        t = _new_transport()
+        t.session_id = "sess-1"
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_client.delete.return_value = mock_response
+        t.terminate_session(mock_client)
+        mock_client.delete.assert_called_once()
+
+    def test_405_does_not_raise(self):
+        t = _new_transport()
+        t.session_id = "sess-1"
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.status_code = 405
+        mock_client.delete.return_value = mock_response
+        t.terminate_session(mock_client)  # Should not raise
+
+    def test_non_200_logs_warning_does_not_raise(self):
+        t = _new_transport()
+        t.session_id = "sess-1"
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_client.delete.return_value = mock_response
+        t.terminate_session(mock_client)  # Should not raise
+
+    def test_exception_is_swallowed(self):
+        t = _new_transport()
+        t.session_id = "sess-1"
+        mock_client = MagicMock()
+        mock_client.delete.side_effect = httpx.ConnectError("refused")
+        t.terminate_session(mock_client)  # Should not raise
+
+
+# ── get_session_id ────────────────────────────────────────────────────────────
+
+
+class TestGetSessionIdNew:
+    def test_returns_none_when_no_session(self):
+        t = _new_transport()
+        assert t.get_session_id() is None
+
+    def test_returns_session_id_when_set(self):
+        t = _new_transport()
+        t.session_id = "my-session"
+        assert t.get_session_id() == "my-session"
+
+
+# ── streamablehttp_client context manager ─────────────────────────────────────
+
+
+class TestStreamablehttpClientContextManagerNew:
+    def test_yields_queues_and_callback(self):
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+                mock_exec.return_value = mock_executor
+
+                with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid):
+                    assert s2c is not None
+                    assert c2s is not None
+                    assert callable(get_sid)
+
+    def test_terminate_on_close_false_does_not_delete(self):
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+                mock_exec.return_value = mock_executor
+
+                with streamablehttp_client("http://example.com/mcp", terminate_on_close=False) as (s2c, c2s, get_sid):
+                    pass
+                mock_client.delete.assert_not_called()
+
+    def test_queue_cleanup_on_outer_exception(self):
+        """Verify cleanup in finally block runs even when create_ssrf raises."""
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_cf.side_effect = RuntimeError("connection failed")
+
+            with pytest.raises(RuntimeError):
+                with streamablehttp_client("http://example.com/mcp"):
+                    pass  # pragma: no cover
+
+    def test_timedelta_args_accepted(self):
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+                mock_exec.return_value = mock_executor
+
+                with streamablehttp_client(
+                    "http://example.com/mcp",
+                    timeout=timedelta(seconds=15),
+                    sse_read_timeout=timedelta(seconds=60),
+                ) as (s2c, c2s, get_sid):
+                    assert callable(get_sid)
+
+    def test_start_get_stream_submits_to_executor(self):
+        """When context starts, post_writer is submitted to executor."""
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            submitted_calls = []
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+
+                def capture_submit(fn, *args, **kwargs):
+                    submitted_calls.append((fn, args))
+
+                mock_executor.submit.side_effect = capture_submit
+                mock_exec.return_value = mock_executor
+
+                with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid):
+                    pass
+
+                # post_writer was submitted
+                assert len(submitted_calls) >= 1
+
+    def test_cleanup_puts_none_sentinels_to_queues(self):
+        """After context exit, None sentinels are put into both queues."""
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+                mock_exec.return_value = mock_executor
+
+                with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid):
+                    pass
+
+                # After context exit, None sentinel should be in c2s queue from cleanup
+                val = c2s.get_nowait()
+                assert val is None
+
+    def test_terminate_called_when_session_id_set(self):
+        """When session_id is set and terminate_on_close=True, terminate_session is called."""
+        from core.mcp.client.streamable_client import streamablehttp_client
+
+        with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
+            mock_client = MagicMock()
+            mock_cf.return_value.__enter__.return_value = mock_client
+
+            mock_delete_resp = MagicMock()
+            mock_delete_resp.status_code = 200
+            mock_client.delete.return_value = mock_delete_resp
+
+            with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec:
+                mock_executor = MagicMock()
+                mock_exec.return_value = mock_executor
+
+                with patch("core.mcp.client.streamable_client.StreamableHTTPTransport") as MockTransport:
+                    mock_transport = MockTransport.return_value
+                    mock_transport.request_headers = {
+                        "Accept": "application/json, text/event-stream",
+                        "content-type": "application/json",
+                    }
+                    mock_transport.timeout = 30
+                    mock_transport.sse_read_timeout = 300
+                    mock_transport.session_id = "active-session"
+                    mock_transport.stop_event = MagicMock()
+                    mock_transport.get_session_id = MagicMock(return_value="active-session")
+
+                    with streamablehttp_client("http://example.com/mcp", terminate_on_close=True) as (
+                        s2c,
+                        c2s,
+                        get_sid,
+                    ):
+                        pass
+
+                    mock_transport.terminate_session.assert_called_once_with(mock_client)
+
+
+# ── Exception hierarchy ───────────────────────────────────────────────────────
+
+
+class TestExceptionHierarchyNew:
+    def test_streamable_http_error_is_exception(self):
+        err = StreamableHTTPError("test")
+        assert isinstance(err, Exception)
+
+    def test_resumption_error_is_streamable_http_error(self):
+        err = ResumptionError("test")
+        assert isinstance(err, StreamableHTTPError)
+        assert isinstance(err, Exception)
+
+
+# ── RequestContext dataclass ──────────────────────────────────────────────────
+
+
+class TestRequestContextNew:
+    def test_creation(self):
+        import queue
+
+        q: queue.Queue = queue.Queue()
+        ctx = RequestContext(
+            client=MagicMock(),
+            headers={"X-Test": "val"},
+            session_id="sid",
+            session_message=SessionMessage(_make_request_msg()),
+            metadata=None,
+            server_to_client_queue=q,
+            sse_read_timeout=30.0,
+        )
+        assert ctx.session_id == "sid"
+        assert ctx.sse_read_timeout == 30.0
+        assert ctx.metadata is None
diff --git a/api/tests/unit_tests/core/mcp/session/test_base_session.py b/api/tests/unit_tests/core/mcp/session/test_base_session.py
new file mode 100644
index 0000000000..1dd916bcf1
--- /dev/null
+++ b/api/tests/unit_tests/core/mcp/session/test_base_session.py
@@ -0,0 +1,617 @@
+import queue
+import time
+from concurrent.futures import Future, ThreadPoolExecutor
+from datetime import timedelta
+from typing import Union
+from unittest.mock import MagicMock, patch
+
+import pytest
+from httpx import HTTPStatusError, Request, Response
+from pydantic import BaseModel, ConfigDict, RootModel
+
+from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.session.base_session import BaseSession, RequestResponder
+from core.mcp.types import (
+    CancelledNotification,
+    ClientNotification,
+    ClientRequest,
+    ErrorData,
+    JSONRPCError,
+    JSONRPCMessage,
+    JSONRPCNotification,
+    JSONRPCResponse,
+    Notification,
+    RequestParams,
+    SessionMessage,
+)
+from core.mcp.types import (
+    Request as MCPRequest,
+)
+
+
+class MockRequestParams(RequestParams):
+    name: str = "default"
+    model_config = ConfigDict(extra="allow")
+
+
+class MockRequest(MCPRequest[MockRequestParams, str]):
+    method: str = "test/request"
+    params: MockRequestParams = MockRequestParams()
+
+
+class MockResult(BaseModel):
+    result: str
+
+
+class MockNotificationParams(BaseModel):
+    message: str
+
+
+class MockNotification(Notification[MockNotificationParams, str]):
+    method: str = "test/notification"
+    params: MockNotificationParams
+
+
+class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]):
+    pass
+
+
+class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]):
+    pass
+
+
+class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.received_requests = []
+        self.received_notifications = []
+        self.handled_incoming = []
+
+    def _received_request(self, responder):
+        self.received_requests.append(responder)
+
+    def _received_notification(self, notification):
+        self.received_notifications.append(notification)
+
+    def _handle_incoming(self, item):
+        self.handled_incoming.append(item)
+
+
+@pytest.fixture
+def streams():
+    return queue.Queue(), queue.Queue()
+
+
+@pytest.mark.timeout(5)
+def test_request_responder_respond(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    on_complete = MagicMock()
+    request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
+
+    responder = RequestResponder(
+        request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
+    )
+
+    with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
+        responder.respond(MockResult(result="ok"))
+
+    with responder as r:
+        r.respond(MockResult(result="ok"))
+        with pytest.raises(AssertionError, match="Request already responded to"):
+            r.respond(MockResult(result="error"))
+
+    assert responder.completed is True
+    on_complete.assert_called_once_with(responder)
+
+    msg = write_stream.get_nowait()
+    assert isinstance(msg.message.root, JSONRPCResponse)
+    assert msg.message.root.result == {"result": "ok"}
+
+
+@pytest.mark.timeout(5)
+def test_request_responder_cancel(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    on_complete = MagicMock()
+    request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
+
+    responder = RequestResponder(
+        request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
+    )
+
+    with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
+        responder.cancel()
+
+    with responder as r:
+        r.cancel()
+
+    assert responder.completed is True
+    on_complete.assert_called_once_with(responder)
+
+    msg = write_stream.get_nowait()
+    assert isinstance(msg.message.root, JSONRPCError)
+    assert msg.message.root.error.message == "Request cancelled"
+
+
+@pytest.mark.timeout(10)
+def test_base_session_lifecycle(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session as s:
+        assert isinstance(s, MockSession)
+        assert s._executor is not None
+        assert s._receiver_future is not None
+
+    session._receiver_future.result(timeout=5.0)
+    assert session._receiver_future.done()
+
+
+@pytest.mark.timeout(5)
+def test_send_request_success(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_response():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"})
+            read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_response, daemon=True)
+    t.start()
+
+    with session:
+        result = session.send_request(request, MockResult)
+        assert result.result == "hello world"
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_request_retry_loop_coverage(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_delayed_response():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            time.sleep(0.2)
+            response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"})
+            read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_delayed_response, daemon=True)
+    t.start()
+
+    with session:
+        result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1))
+        assert result.result == "slow"
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_request_jsonrpc_error(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_error():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error"))
+            read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_error, daemon=True)
+    t.start()
+
+    with session:
+        with pytest.raises(MCPConnectionError) as exc:
+            session.send_request(request, MockResult)
+        assert exc.value.args[0].message == "Error"
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_request_auth_error(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_error():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized"))
+            read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_error, daemon=True)
+    t.start()
+
+    with session:
+        with pytest.raises(MCPAuthError):
+            session.send_request(request, MockResult)
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_request_http_status_error_coverage(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_direct_http_error():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            # To cover line 263 in base_session.py, we MUST put a non-401 HTTPStatusError
+            # DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError.
+            response = Response(status_code=403, request=Request("GET", "http://test"))
+            error = HTTPStatusError("Forbidden", request=response.request, response=response)
+            session._response_streams[req_id].put(error)
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_direct_http_error, daemon=True)
+    t.start()
+
+    # We still need the session for request ID generation and queue setup
+    with session:
+        with pytest.raises(MCPConnectionError) as exc:
+            session.send_request(request, MockResult)
+        assert exc.value.args[0].code == 403
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_request_http_status_auth_error(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_error():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            response = Response(status_code=401, request=Request("GET", "http://test"))
+            error = HTTPStatusError("Unauthorized", request=response.request, response=response)
+            read_stream.put(error)
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_error, daemon=True)
+    t.start()
+
+    with session:
+        with pytest.raises(MCPAuthError):
+            session.send_request(request, MockResult)
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(5)
+def test_send_notification(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    notification = MockNotification(method="notify", params=MockNotificationParams(message="hi"))
+
+    session.send_notification(notification, related_request_id="rel-1")
+
+    msg = write_stream.get_nowait()
+    assert isinstance(msg.message.root, JSONRPCNotification)
+    assert msg.message.root.method == "notify"
+    assert msg.message.root.params == {"message": "hi"}
+    assert msg.metadata.related_request_id == "rel-1"
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_request(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}}
+        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
+
+        for _ in range(30):
+            if session.received_requests:
+                break
+            time.sleep(0.1)
+
+        assert len(session.received_requests) == 1
+        responder = session.received_requests[0]
+        assert responder.request_id == 1
+        assert responder.request.root.method == "test/request"
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_notification(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}}
+        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
+
+        for _ in range(30):
+            if session.received_notifications:
+                break
+            time.sleep(0.1)
+
+        assert len(session.received_notifications) == 1
+        assert isinstance(session.received_notifications[0].root, MockNotification)
+        assert session.received_notifications[0].root.method == "test/notification"
+
+
+@pytest.mark.timeout(15)
+def test_receive_loop_cancel_notification(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification)
+
+    with session:
+        req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}}
+        read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
+
+        for _ in range(30):
+            if "req-1" in session._in_flight:
+                break
+            time.sleep(0.1)
+
+        assert "req-1" in session._in_flight
+        responder = session._in_flight["req-1"]
+
+        with responder:
+            cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}}
+            read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload)))
+
+            for _ in range(30):
+                if responder.completed:
+                    break
+                time.sleep(0.1)
+
+        assert responder.completed is True
+        msg = write_stream.get(timeout=2)
+        assert isinstance(msg.message.root, JSONRPCError)
+        assert msg.message.root.id == "req-1"
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_exception(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        read_stream.put(Exception("Unexpected error"))
+        for _ in range(30):
+            if any(isinstance(x, Exception) for x in session.handled_incoming):
+                break
+            time.sleep(0.1)
+
+        assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming)
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_http_status_error(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        session._request_id = 1
+        resp_queue = queue.Queue()
+        session._response_streams[0] = resp_queue
+
+        response = Response(status_code=401, request=Request("GET", "http://test"))
+        # Using 401 specifically as _receive_loop preserves it
+        error = HTTPStatusError("Unauthorized", request=response.request, response=response)
+        read_stream.put(error)
+
+        got = resp_queue.get(timeout=2)
+        assert isinstance(got, HTTPStatusError)
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_http_status_error_non_401(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        session._request_id = 1
+        resp_queue = queue.Queue()
+        session._response_streams[0] = resp_queue
+
+        response = Response(status_code=500, request=Request("GET", "http://test"))
+        error = HTTPStatusError("Server Error", request=response.request, response=response)
+        read_stream.put(error)
+
+        got = resp_queue.get(timeout=2)
+        assert isinstance(got, JSONRPCError)
+        assert got.error.code == 500
+
+
+@pytest.mark.timeout(5)
+def test_check_receiver_status_fail(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    executor = ThreadPoolExecutor(max_workers=1)
+
+    def raise_err():
+        raise RuntimeError("Receiver failed")
+
+    future = executor.submit(raise_err)
+    session._receiver_future = future
+
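+    # Consume the exception once so the future is marked done;
+    # check_receiver_status then re-raises it.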
+    try:
+        future.result()
+    except Exception:
+        pass
+
+    with pytest.raises(RuntimeError, match="Receiver failed"):
+        session.check_receiver_status()
+    executor.shutdown()
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_unknown_request_id(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True})
+        read_stream.put(SessionMessage(message=JSONRPCMessage(resp)))
+
+        for _ in range(30):
+            if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming):
+                break
+            time.sleep(0.1)
+
+        assert any("Server Error" in str(x) for x in session.handled_incoming)
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_http_error_unknown_id(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with session:
+        response = Response(status_code=401, request=Request("GET", "http://test"))
+        error = HTTPStatusError("Unauthorized", request=response.request, response=response)
+        read_stream.put(error)
+
+        for _ in range(30):
+            if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming):
+                break
+            time.sleep(0.1)
+
+        assert any("unknown request ID" in str(x) for x in session.handled_incoming)
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_validation_error_notification(streams):
+    from core.mcp.session.base_session import logger
+
+    with patch.object(logger, "warning") as mock_warning:
+        read_stream, write_stream = streams
+        session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification])
+
+        with session:
+            notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}}
+            read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
+            time.sleep(1.0)
+
+        assert mock_warning.called
+
+
+@pytest.mark.timeout(5)
+def test_send_request_none_response(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    def mock_none():
+        try:
+            msg = write_stream.get(timeout=2)
+            req_id = msg.message.root.id
+            session._response_streams[req_id].put(None)
+        except Exception:
+            pass
+
+    import threading
+
+    t = threading.Thread(target=mock_none, daemon=True)
+    t.start()
+
+    with session:
+        with pytest.raises(MCPConnectionError) as exc:
+            session.send_request(request, MockResult)
+        assert exc.value.args[0].message == "No response received"
+    t.join(timeout=1)
+
+
+@pytest.mark.timeout(15)
+def test_session_exit_timeout(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    mock_future = MagicMock(spec=Future)
+    mock_future.result.side_effect = TimeoutError()
+    mock_future.done.return_value = False
+
+    session._receiver_future = mock_future
+    session._executor = MagicMock(spec=ThreadPoolExecutor)
+
+    session.__exit__(None, None, None)
+
+    mock_future.cancel.assert_called_once()
+    session._executor.shutdown.assert_called_once_with(wait=False)
+
+
+@pytest.mark.timeout(10)
+def test_receive_loop_fatal_exception(streams):
+    read_stream, write_stream = streams
+    session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")):
+        with patch("core.mcp.session.base_session.logger") as mock_logger:
+            with pytest.raises(RuntimeError, match="Fatal loop error"):
+                with session:
+                    pass
+            mock_logger.exception.assert_called_with("Error in message processing loop")
+
+
+@pytest.mark.timeout(5)
+def test_receive_loop_empty_coverage(streams):
+    with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1):
+        read_stream, write_stream = streams
+        session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+        with session:
+            time.sleep(0.3)
+
+
+@pytest.mark.timeout(2)
+def test_base_methods_noop(streams):
+    read_stream, write_stream = streams
+    session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
+
+    session._received_request(MagicMock())
+    session._received_notification(MagicMock())
+    session.send_progress_notification("token", 0.5)
+    session._handle_incoming(MagicMock())
+
+
+@pytest.mark.timeout(5)
+def test_send_request_session_timeout_retry_6(streams):
+    read_stream, write_stream = streams
+    session = MockSession(
+        read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1)
+    )
+
+    request = MockRequest(method="test", params=MockRequestParams(name="world"))
+
+    with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]):
+        with pytest.raises(RuntimeError, match="timeout_broken"):
+            session.send_request(request, MockResult)
diff --git a/api/tests/unit_tests/core/mcp/session/test_client_session.py b/api/tests/unit_tests/core/mcp/session/test_client_session.py
new file mode 100644
index 0000000000..c7b9d3cfa9
--- /dev/null
+++ b/api/tests/unit_tests/core/mcp/session/test_client_session.py
@@ -0,0 +1,576 @@
+import queue
+from unittest.mock import MagicMock
+
+import pytest
+from pydantic import AnyUrl
+
+from core.mcp import types
+from core.mcp.session.base_session import RequestResponder, SessionMessage
+from core.mcp.session.client_session import (
+    ClientSession,
+    _default_list_roots_callback,
+    _default_logging_callback,
+    _default_message_handler,
+    _default_sampling_callback,
+)
+
+
+@pytest.fixture
+def streams():
+    return queue.Queue(), queue.Queue()
+
+
+def test_client_session_init(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    assert session._client_info.name == "Dify"
+    assert session._sampling_callback == _default_sampling_callback
+    assert session._list_roots_callback == _default_list_roots_callback
+    assert session._logging_callback == _default_logging_callback
+    assert session._message_handler == _default_message_handler
+
+
+def test_client_session_init_custom(streams):
+    read_stream, write_stream = streams
+    sampling_cb = MagicMock()
+    list_roots_cb = MagicMock()
+    logging_cb = MagicMock()
+    msg_handler = MagicMock()
+    client_info = types.Implementation(name="Custom", version="1.0")
+
+    session = ClientSession(
+        read_stream,
+        write_stream,
+        sampling_callback=sampling_cb,
+        list_roots_callback=list_roots_cb,
+        logging_callback=logging_cb,
+        message_handler=msg_handler,
+        client_info=client_info,
+    )
+
+    assert session._client_info == client_info
+    assert session._sampling_callback == sampling_cb
+    assert session._list_roots_callback == list_roots_cb
+    assert session._logging_callback == logging_cb
+    assert session._message_handler == msg_handler
+
+
+def test_initialize_success(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    expected_result = types.InitializeResult(
+        protocolVersion=types.LATEST_PROTOCOL_VERSION,
+        capabilities=types.ServerCapabilities(),
+        serverInfo=types.Implementation(name="test-server", version="1.0"),
+    )
+
+    def mock_server():
+        # Handle initialize request
+        msg = write_stream.get(timeout=2)
+        req_id = msg.message.root.id
+
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result=expected_result.model_dump())
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+        # Expect initialized notification
+        notif = write_stream.get(timeout=2)
+        assert notif.message.root.method == "notifications/initialized"
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.initialize()
+        assert result.protocolVersion == types.LATEST_PROTOCOL_VERSION
+        assert result.serverInfo.name == "test-server"
+
+    t.join(timeout=1)
+
+
+def test_initialize_custom_capabilities(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(
+        read_stream, write_stream, sampling_callback=lambda c, p: None, list_roots_callback=lambda c: None
+    )
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        params = msg.message.root.params
+        # Check that capabilities are set because we provided custom callbacks
+        assert params["capabilities"]["sampling"] is not None
+        assert params["capabilities"]["roots"]["listChanged"] is True
+
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(
+            jsonrpc="2.0",
+            id=req_id,
+            result={
+                "protocolVersion": types.LATEST_PROTOCOL_VERSION,
+                "capabilities": {},
+                "serverInfo": {"name": "test", "version": "1.0"},
+            },
+        )
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+        write_stream.get(timeout=2)  # initialized notif
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        session.initialize()
+    t.join(timeout=1)
+
+
+def test_initialize_unsupported_version(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(
+            jsonrpc="2.0",
+            id=req_id,
+            result={
+                "protocolVersion": "0.0.1",  # Unsupported
+                "capabilities": {},
+                "serverInfo": {"name": "test", "version": "1.0"},
+            },
+        )
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        with pytest.raises(RuntimeError, match="Unsupported protocol version"):
+            session.initialize()
+    t.join(timeout=1)
+
+
+def test_send_ping(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "ping"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        session.send_ping()
+    t.join(timeout=1)
+
+
+def test_send_progress_notification(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    session.send_progress_notification(progress_token="token", progress=50.0, total=100.0)
+
+    msg = write_stream.get_nowait()
+    assert msg.message.root.method == "notifications/progress"
+    assert msg.message.root.params["progressToken"] == "token"
+    assert msg.message.root.params["progress"] == 50.0
+    assert msg.message.root.params["total"] == 100.0
+
+
+def test_set_logging_level(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "logging/setLevel"
+        assert msg.message.root.params["level"] == "debug"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        session.set_logging_level("debug")
+    t.join(timeout=1)
+
+
+def test_list_resources(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "resources/list"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resources": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.list_resources()
+        assert result.resources == []
+    t.join(timeout=1)
+
+
+def test_list_resource_templates(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "resources/templates/list"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resourceTemplates": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.list_resource_templates()
+        assert result.resourceTemplates == []
+    t.join(timeout=1)
+
+
+def test_read_resource(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+    uri = AnyUrl("file:///test")
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "resources/read"
+        assert msg.message.root.params["uri"] == str(uri)
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"contents": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.read_resource(uri)
+        assert result.contents == []
+    t.join(timeout=1)
+
+
+def test_subscribe_resource(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+    uri = AnyUrl("file:///test")
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "resources/subscribe"
+        assert msg.message.root.params["uri"] == str(uri)
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        session.subscribe_resource(uri)
+    t.join(timeout=1)
+
+
+def test_unsubscribe_resource(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+    uri = AnyUrl("file:///test")
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "resources/unsubscribe"
+        assert msg.message.root.params["uri"] == str(uri)
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        session.unsubscribe_resource(uri)
+    t.join(timeout=1)
+
+
+def test_call_tool(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "tools/call"
+        assert msg.message.root.params["name"] == "test-tool"
+        assert msg.message.root.params["arguments"] == {"arg": 1}
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"content": [], "isError": False})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.call_tool("test-tool", arguments={"arg": 1})
+        assert result.isError is False
+    t.join(timeout=1)
+
+
+def test_list_prompts(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "prompts/list"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"prompts": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.list_prompts()
+        assert result.prompts == []
+    t.join(timeout=1)
+
+
+def test_get_prompt(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "prompts/get"
+        assert msg.message.root.params["name"] == "test-prompt"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"messages": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.get_prompt("test-prompt")
+        assert result.messages == []
+    t.join(timeout=1)
+
+
+def test_complete(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+    ref = types.PromptReference(type="ref/prompt", name="test")
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "completion/complete"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"completion": {"values": [], "hasMore": False}})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.complete(ref, argument={"name": "val", "value": "x"})
+        assert result.completion.hasMore is False
+    t.join(timeout=1)
+
+
+def test_list_tools(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    def mock_server():
+        msg = write_stream.get(timeout=2)
+        assert msg.message.root.method == "tools/list"
+        req_id = msg.message.root.id
+        resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"tools": []})
+        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
+
+    import threading
+
+    t = threading.Thread(target=mock_server, daemon=True)
+    t.start()
+
+    with session:
+        result = session.list_tools()
+        assert result.tools == []
+    t.join(timeout=1)
+
+
+def test_send_roots_list_changed(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    session.send_roots_list_changed()
+
+    msg = write_stream.get_nowait()
+    assert msg.message.root.method == "notifications/roots/list_changed"
+
+
+def test_received_request_sampling(streams):
+    read_stream, write_stream = streams
+    sampling_cb = MagicMock(
+        return_value=types.CreateMessageResult(
+            role="assistant", content=types.TextContent(type="text", text="hello"), model="gpt-4"
+        )
+    )
+    session = ClientSession(read_stream, write_stream, sampling_callback=sampling_cb)
+
+    req = types.ServerRequest(
+        root=types.CreateMessageRequest(
+            method="sampling/createMessage", params=types.CreateMessageRequestParams(messages=[], maxTokens=100)
+        )
+    )
+
+    responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
+
+    session._received_request(responder)
+
+    msg = write_stream.get_nowait()
+    assert msg.message.root.result["model"] == "gpt-4"
+    sampling_cb.assert_called_once()
+
+
+def test_received_request_list_roots(streams):
+    read_stream, write_stream = streams
+    list_roots_cb = MagicMock(return_value=types.ListRootsResult(roots=[]))
+    session = ClientSession(read_stream, write_stream, list_roots_callback=list_roots_cb)
+
+    req = types.ServerRequest(root=types.ListRootsRequest(method="roots/list"))
+
+    responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
+
+    session._received_request(responder)
+
+    msg = write_stream.get_nowait()
+    assert msg.message.root.result["roots"] == []
+    list_roots_cb.assert_called_once()
+
+
+def test_received_request_ping(streams):
+    read_stream, write_stream = streams
+    session = ClientSession(read_stream, write_stream)
+
+    req = types.ServerRequest(root=types.PingRequest(method="ping"))
+
+    responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
+
+    session._received_request(responder)
+
+    msg = write_stream.get_nowait()
+    assert msg.message.root.result == {}
+
+
+def test_handle_incoming(streams):
+    read_stream, write_stream = streams
+    msg_handler = MagicMock()
+    session = ClientSession(read_stream, write_stream, message_handler=msg_handler)
+
+    item = MagicMock()
+    session._handle_incoming(item)
+    msg_handler.assert_called_once_with(item)
+
+
+def test_received_notification_logging(streams):
+    read_stream, write_stream = streams
+    logging_cb = MagicMock()
+    session = ClientSession(read_stream, write_stream, logging_callback=logging_cb)
+
+    notif = types.ServerNotification(
+        root=types.LoggingMessageNotification(
+            method="notifications/message",
method="notifications/message", + params=types.LoggingMessageNotificationParams(level="info", data={"msg": "test"}), + ) + ) + + session._received_notification(notif) + logging_cb.assert_called_once() + assert logging_cb.call_args[0][0].level == "info" + + +def test_default_message_handler(): + # Exception case + with pytest.raises(ValueError, match="test error"): + _default_message_handler(Exception("test error")) + + # Notification case - should do nothing + _default_message_handler(MagicMock(spec=types.ServerNotification)) + + # RequestResponder case - should do nothing + _default_message_handler(MagicMock(spec=RequestResponder)) + + +def test_default_sampling_callback(): + ctx = MagicMock() + params = MagicMock() + res = _default_sampling_callback(ctx, params) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_list_roots_callback(): + ctx = MagicMock() + res = _default_list_roots_callback(ctx) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_logging_callback(): + params = MagicMock() + _default_logging_callback(params) # Should do nothing + + +def test_received_notification_unknown(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + # Use a notification type that is NOT LoggingMessageNotification + notif = types.ServerNotification( + root=types.ResourceListChangedNotification(method="notifications/resources/list_changed") + ) + + session._received_notification(notif) + # Should just pass (case _:) diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py index c0420d3371..c245b4a77e 100644 --- a/api/tests/unit_tests/core/mcp/test_mcp_client.py +++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py @@ -2,13 +2,16 @@ from contextlib import ExitStack from types import TracebackType -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy.orm import Session -from core.mcp.error import MCPConnectionError +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient -from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations +from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations class TestMCPClient: @@ -380,3 +383,256 @@ class TestMCPClient: timeout=30.0, sse_read_timeout=60.0, ) + + +class TestMCPClientWithAuthRetry: + """Test suite for MCPClientWithAuthRetry.""" + + @pytest.fixture + def mock_provider(self): + provider = MagicMock(spec=MCPProviderEntity) + provider.id = "test-provider-id" + provider.tenant_id = "test-tenant-id" + provider.retrieve_tokens.return_value = OAuthTokens( + access_token="new-token", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh-token", + ) + return provider + + @pytest.fixture + def auth_client(self, mock_provider): + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer old-token"}, + provider_entity=mock_provider, + authorization_code="test-code", + by_server_id=True, + ) + return client + + def test_init(self, mock_provider): + """Test initialization.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer 
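Note: every ClientSession test above spins up the same daemon-thread mock server by hand. A minimal sketch of a helper that could factor out that boilerplate follows — run_one_shot_server and its handler argument are hypothetical names, not part of this diff; types and SessionMessage are the names these tests already import.

import threading


def run_one_shot_server(write_stream, read_stream, handler, timeout=2):
    # Hypothetical helper: answer exactly one JSON-RPC request taken from the
    # queue-backed write stream; `handler` maps the request to a result dict.
    def _serve():
        msg = write_stream.get(timeout=timeout)
        resp = types.JSONRPCResponse(jsonrpc="2.0", id=msg.message.root.id, result=handler(msg))
        read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))

    t = threading.Thread(target=_serve, daemon=True)
    t.start()
    return t  # callers run the session under `with session:` and then t.join(timeout=1)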
diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py
index c0420d3371..c245b4a77e 100644
--- a/api/tests/unit_tests/core/mcp/test_mcp_client.py
+++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py
@@ -2,13 +2,16 @@
 from contextlib import ExitStack
 from types import TracebackType
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import pytest
+from sqlalchemy.orm import Session

-from core.mcp.error import MCPConnectionError
+from core.entities.mcp_provider import MCPProviderEntity
+from core.mcp.auth_client import MCPClientWithAuthRetry
+from core.mcp.error import MCPAuthError, MCPConnectionError
 from core.mcp.mcp_client import MCPClient
-from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
+from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations


 class TestMCPClient:
@@ -380,3 +383,256 @@ class TestMCPClient:
             timeout=30.0,
             sse_read_timeout=60.0,
         )
+
+
+class TestMCPClientWithAuthRetry:
+    """Test suite for MCPClientWithAuthRetry."""
+
+    @pytest.fixture
+    def mock_provider(self):
+        provider = MagicMock(spec=MCPProviderEntity)
+        provider.id = "test-provider-id"
+        provider.tenant_id = "test-tenant-id"
+        provider.retrieve_tokens.return_value = OAuthTokens(
+            access_token="new-token",
+            token_type="Bearer",
+            expires_in=3600,
+            refresh_token="refresh-token",
+        )
+        return provider
+
+    @pytest.fixture
+    def auth_client(self, mock_provider):
+        client = MCPClientWithAuthRetry(
+            server_url="http://test.example.com",
+            headers={"Authorization": "Bearer old-token"},
+            provider_entity=mock_provider,
+            authorization_code="test-code",
+            by_server_id=True,
+        )
+        return client
+
+    def test_init(self, mock_provider):
+        """Test initialization."""
+        client = MCPClientWithAuthRetry(
+            server_url="http://test.example.com",
+            headers={"Authorization": "Bearer test"},
+            timeout=30.0,
+            provider_entity=mock_provider,
+            authorization_code="initial-code",
+            by_server_id=True,
+        )
+
+        assert client.server_url == "http://test.example.com"
+        assert client.headers == {"Authorization": "Bearer test"}
+        assert client.timeout == 30.0
+        assert client.provider_entity == mock_provider
+        assert client.authorization_code == "initial-code"
+        assert client.by_server_id is True
+        assert client._has_retried is False
+
+    @patch("core.mcp.auth_client.db")
+    @patch("core.mcp.auth_client.Session")
+    @patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
+    def test_handle_auth_error_success(
+        self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
+    ):
+        mock_session = MagicMock(spec=Session)
+        mock_session_class.return_value.__enter__.return_value = mock_session
+
+        mock_service = mock_service_class.return_value
+        new_provider = MagicMock(spec=MCPProviderEntity)
+        new_provider.retrieve_tokens.return_value = OAuthTokens(
+            access_token="new-access-token",
+            token_type="Bearer",
+            expires_in=3600,
+            refresh_token="new-refresh-token",
+        )
+        mock_service.get_provider_entity.return_value = new_provider
+
+        # MCPAuthError parses resource_metadata and scope from www_authenticate_header
+        www_auth = 'Bearer resource_metadata="http://meta", scope="read"'
+        error = MCPAuthError("Auth failed", www_authenticate_header=www_auth)
+
+        auth_client._handle_auth_error(error)
+
+        # Verify service calls - error.resource_metadata_url and error.scope_hint are parsed from header
+        mock_service.auth_with_actions.assert_called_once_with(
+            mock_provider,
+            "test-code",
+            resource_metadata_url="http://meta",
+            scope_hint="read",
+        )
+        mock_service.get_provider_entity.assert_called_once_with(
+            mock_provider.id, mock_provider.tenant_id, by_server_id=True
+        )
+
+        # Verify client updates
+        assert auth_client.headers["Authorization"] == "Bearer new-access-token"
+        assert auth_client.authorization_code is None
+        assert auth_client._has_retried is True
+        assert auth_client.provider_entity == new_provider
+
+    def test_handle_auth_error_no_provider(self, auth_client):
+        """Test auth error handling when no provider entity is set."""
+        auth_client.provider_entity = None
+        error = MCPAuthError("Auth failed")
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._handle_auth_error(error)
+
+        assert exc_info.value == error
+
+    def test_handle_auth_error_already_retried(self, auth_client):
+        """Test auth error handling when already retried."""
+        auth_client._has_retried = True
+        error = MCPAuthError("Auth failed")
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._handle_auth_error(error)
+
+        assert exc_info.value == error
+
+    @patch("core.mcp.auth_client.db")
+    @patch("core.mcp.auth_client.Session")
+    @patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
+    def test_handle_auth_error_no_token(
+        self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
+    ):
+        """Test auth error handling when no token is received."""
+        mock_session_class.return_value.__enter__.return_value = MagicMock()
+        mock_service = mock_service_class.return_value
+
+        new_provider = MagicMock(spec=MCPProviderEntity)
+        new_provider.retrieve_tokens.return_value = None
+        mock_service.get_provider_entity.return_value = new_provider
+
+        error = MCPAuthError("Auth failed")
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._handle_auth_error(error)
+
+        assert "Authentication failed - no token received" in str(exc_info.value)
+
+    @patch("core.mcp.auth_client.db")
+    @patch("core.mcp.auth_client.Session")
+    @patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
+    def test_handle_auth_error_generic_exception(self, mock_service_class, mock_session_class, mock_db, auth_client):
+        """Test auth error handling when a generic exception occurs."""
+        mock_session_class.side_effect = Exception("DB error")
+
+        error = MCPAuthError("Auth failed")
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._handle_auth_error(error)
+
+        assert "Authentication retry failed: DB error" in str(exc_info.value)
+
+    @patch("core.mcp.auth_client.db")
+    @patch("core.mcp.auth_client.Session")
+    @patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
+    def test_handle_auth_error_mcp_auth_error_propagation(
+        self, mock_service_class, mock_session_class, mock_db, auth_client
+    ):
+        """Test that MCPAuthError during refresh is propagated as is."""
+        mock_session_class.return_value.__enter__.return_value = MagicMock()
+        mock_service = mock_service_class.return_value
+        mock_service.auth_with_actions.side_effect = MCPAuthError("Refresh failed")
+
+        error = MCPAuthError("Initial auth failed")
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._handle_auth_error(error)
+
+        assert "Refresh failed" in str(exc_info.value)
+
+    def test_execute_with_retry_success_first_try(self, auth_client):
+        """Test execution success on first try."""
+        mock_func = MagicMock(return_value="success")
+
+        result = auth_client._execute_with_retry(mock_func, "arg1", kwarg1="val1")
+
+        assert result == "success"
+        mock_func.assert_called_once_with("arg1", kwarg1="val1")
+        assert auth_client._has_retried is False
+
+    @patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
+    @patch.object(MCPClientWithAuthRetry, "_initialize")
+    def test_execute_with_retry_success_on_retry_initialized(self, mock_initialize, mock_handle_auth, auth_client):
+        """Test execution success on retry after auth error when client was already initialized."""
+        mock_func = MagicMock()
+        mock_func.side_effect = [MCPAuthError("Auth failed"), "success"]
+
+        auth_client._initialized = True
+        auth_client._exit_stack = MagicMock()
+
+        result = auth_client._execute_with_retry(mock_func, "arg")
+
+        assert result == "success"
+        assert mock_func.call_count == 2
+        mock_handle_auth.assert_called_once()
+        mock_initialize.assert_called_once()
+        auth_client._exit_stack.close.assert_called_once()
+        assert auth_client._has_retried is False
+
+    @patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
+    @patch.object(MCPClientWithAuthRetry, "_initialize")
+    def test_execute_with_retry_success_on_retry_not_initialized(self, mock_initialize, mock_handle_auth, auth_client):
+        """Test retry when client was NOT initialized (skips cleanup/re-init)."""
+        mock_func = MagicMock()
+        mock_func.side_effect = [MCPAuthError("Auth failed"), "result"]
+
+        auth_client._initialized = False
+
+        result = auth_client._execute_with_retry(mock_func, "arg")
+
+        assert result == "result"
+        assert mock_func.call_count == 2
+        mock_handle_auth.assert_called_once()
+        mock_initialize.assert_not_called()
+        assert auth_client._has_retried is False
+
+    @patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
+    def test_execute_with_retry_failure_on_retry(self, mock_handle_auth, auth_client):
+        """Test execution failure even after retry."""
+        mock_func = MagicMock()
+        mock_func.side_effect = [MCPAuthError("First fail"), MCPAuthError("Second fail")]
+
+        with pytest.raises(MCPAuthError) as exc_info:
+            auth_client._execute_with_retry(mock_func, "arg")
+
+        assert "Second fail" in str(exc_info.value)
+        assert mock_func.call_count == 2
+        mock_handle_auth.assert_called_once()
+        assert auth_client._has_retried is False
+
+    @patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
+    def test_auth_client_context_manager_enter(self, mock_execute_retry, auth_client):
+        """Test context manager __enter__."""
+        auth_client.__enter__()
+
+        mock_execute_retry.assert_called_once()
+        func = mock_execute_retry.call_args[0][0]
+
+        with patch("core.mcp.mcp_client.MCPClient.__enter__") as mock_base_enter:
+            result = func()
+            assert result == auth_client
+            mock_base_enter.assert_called_once()
+
+    @patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
+    def test_auth_client_list_tools(self, mock_execute_retry, auth_client):
+        """Test list_tools with retry."""
+        auth_client.list_tools()
+
+        mock_execute_retry.assert_called_once()
+        assert mock_execute_retry.call_args[0][0].__name__ == "list_tools"
+
+    @patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
+    def test_auth_client_invoke_tool(self, mock_execute_retry, auth_client):
+        """Test invoke_tool with retry."""
+        auth_client.invoke_tool("test-tool", {"arg": "val"})
+
+        mock_execute_retry.assert_called_once()
+        assert mock_execute_retry.call_args[0][0].__name__ == "invoke_tool"
+        assert mock_execute_retry.call_args[0][1] == "test-tool"
+        assert mock_execute_retry.call_args[0][2] == {"arg": "val"}
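Taken together, the retry tests above pin down a precise contract for _execute_with_retry: one retry at most, token refresh via _handle_auth_error, reconnect only if already initialized, and _has_retried reset to False afterwards. As a reading aid, here is that contract as a control-flow sketch — illustrative only, not the actual body of core/mcp/auth_client.py:

def _execute_with_retry(self, func, *args, **kwargs):
    try:
        return func(*args, **kwargs)
    except MCPAuthError as e:
        self._handle_auth_error(e)  # refreshes tokens, or re-raises if it cannot
        if self._initialized:
            self._exit_stack.close()  # drop the stale connection
            self._initialize()  # reconnect with the refreshed Authorization header
        try:
            return func(*args, **kwargs)  # second and final attempt
        finally:
            self._has_retried = False  # every test asserts this ends up False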
diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py
new file mode 100644
index 0000000000..5ecfe01808
--- /dev/null
+++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py
@@ -0,0 +1,969 @@
+"""Comprehensive unit tests for core/memory/token_buffer_memory.py"""
+
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+import pytest
+
+from core.memory.token_buffer_memory import TokenBufferMemory
+from dify_graph.model_runtime.entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessageRole,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from models.model import AppMode
+
+# ---------------------------------------------------------------------------
+# Helpers / shared fixtures
+# ---------------------------------------------------------------------------
+
+
+def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock:
+    """Return a minimal Conversation mock."""
+    conv = MagicMock()
+    conv.id = str(uuid4())
+    conv.mode = mode
+    conv.model_config = {}
+    return conv
+
+
+def _make_model_instance() -> MagicMock:
+    """Return a ModelInstance mock whose token counter returns a constant."""
+    mi = MagicMock()
+    mi.get_llm_num_tokens.return_value = 100
+    return mi
+
+
+def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock:
+    msg = MagicMock()
+    msg.id = str(uuid4())
+    msg.query = "user query"
+    msg.answer = answer
+    msg.answer_tokens = answer_tokens
+    msg.workflow_run_id = str(uuid4())
+    msg.created_at = MagicMock()
+    return msg
+
+
+# ===========================================================================
+# Tests for __init__ and workflow_run_repo property
+# ===========================================================================
+
+
+class TestInit:
+    def test_init_stores_conversation_and_model_instance(self):
+        conv = _make_conversation()
+        mi = _make_model_instance()
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+        assert mem.conversation is conv
+        assert mem.model_instance is mi
+        assert mem._workflow_run_repo is None
+
+    def test_workflow_run_repo_is_created_lazily(self):
+        conv = _make_conversation()
+        mi = _make_model_instance()
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+
+        mock_repo = MagicMock()
+        with (
+            patch("core.memory.token_buffer_memory.sessionmaker") as mock_sm,
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
+                return_value=mock_repo,
+            ),
+        ):
+            mock_db.engine = MagicMock()
+            repo = mem.workflow_run_repo
+            assert repo is mock_repo
+            assert mem._workflow_run_repo is mock_repo
+
+    def test_workflow_run_repo_cached_after_first_access(self):
+        conv = _make_conversation()
+        mi = _make_model_instance()
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+
+        existing_repo = MagicMock()
+        mem._workflow_run_repo = existing_repo
+
+        with patch(
+            "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
+        ) as mock_factory:
+            repo = mem.workflow_run_repo
+            mock_factory.assert_not_called()
+            assert repo is existing_repo
+
+
+# ===========================================================================
+# Tests for _build_prompt_message_with_files
+# ===========================================================================
+
+
+class TestBuildPromptMessageWithFiles:
+    """Tests for the private _build_prompt_message_with_files method."""
+
+    # ------------------------------------------------------------------
+    # Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch)
+    # ------------------------------------------------------------------
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_no_files_user_message(self, mode):
+        """When file_extra_config is falsy or app_record is None → plain UserPromptMessage."""
+        conv = _make_conversation(mode)
+        mi = _make_model_instance()
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+
+        with patch(
+            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+            return_value=None,  # falsy → file_objs = []
+        ):
+            result = mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="hello",
+                message=_make_message(),
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
+
+        assert isinstance(result, UserPromptMessage)
+        assert result.content == "hello"
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_no_files_assistant_message(self, mode):
+        """Plain AssistantPromptMessage when no files and is_user_message=False."""
+        conv = _make_conversation(mode)
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        with patch(
+            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+            return_value=None,
+        ):
+            result = mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="ai reply",
+                message=_make_message(),
+                app_record=None,
+                is_user_message=False,
+            )
+
+        assert isinstance(result, AssistantPromptMessage)
+        assert result.content == "ai reply"
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_with_files_user_message(self, mode):
+        """When files are present, returns UserPromptMessage with list content."""
+        conv = _make_conversation(mode)
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        mock_file_extra_config = MagicMock()
+        mock_file_extra_config.image_config = None  # no detail override
+
+        mock_file_obj = MagicMock()
+        # Must be a real entity so Pydantic's tagged union discriminator can validate it
+        real_image_content = ImagePromptMessageContent(
+            url="http://example.com/img.png", format="png", mime_type="image/png"
+        )
+
+        mock_message_file = MagicMock()
+        mock_app_record = MagicMock()
+        mock_app_record.tenant_id = "tenant-1"
+
+        with (
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=mock_file_extra_config,
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
+                return_value=mock_file_obj,
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
+                return_value=real_image_content,
+            ),
+        ):
+            result = mem._build_prompt_message_with_files(
+                message_files=[mock_message_file],
+                text_content="user text",
+                message=_make_message(),
+                app_record=mock_app_record,
+                is_user_message=True,
+            )
+
+        assert isinstance(result, UserPromptMessage)
+        assert isinstance(result.content, list)
+        # Last element should be TextPromptMessageContent
+        assert isinstance(result.content[-1], TextPromptMessageContent)
+        assert result.content[-1].data == "user text"
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_with_files_assistant_message(self, mode):
+        """When files are present, returns AssistantPromptMessage with list content."""
+        conv = _make_conversation(mode)
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        mock_file_extra_config = MagicMock()
+        mock_file_extra_config.image_config = None
+
+        mock_file_obj = MagicMock()
+        real_image_content = ImagePromptMessageContent(
+            url="http://example.com/img.png", format="png", mime_type="image/png"
+        )
+        mock_app_record = MagicMock()
+        mock_app_record.tenant_id = "tenant-1"
+
+        with (
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=mock_file_extra_config,
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
+                return_value=mock_file_obj,
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
+                return_value=real_image_content,
+            ),
+        ):
+            result = mem._build_prompt_message_with_files(
+                message_files=[MagicMock()],
+                text_content="ai text",
+                message=_make_message(),
+                app_record=mock_app_record,
+                is_user_message=False,
+            )
+
+        assert isinstance(result, AssistantPromptMessage)
+        assert isinstance(result.content, list)
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_with_files_image_detail_overridden(self, mode):
+        """When image_config.detail is set, detail is taken from config."""
+        conv = _make_conversation(mode)
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        mock_image_config = MagicMock()
+        mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW
+
+        mock_file_extra_config = MagicMock()
+        mock_file_extra_config.image_config = mock_image_config
+
+        mock_app_record = MagicMock()
+        mock_app_record.tenant_id = "tenant-1"
+
+        real_image_content = ImagePromptMessageContent(
+            url="http://example.com/img.png", format="png", mime_type="image/png"
+        )
+
+        with (
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=mock_file_extra_config,
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_factory.build_from_message_file",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "core.memory.token_buffer_memory.file_manager.to_prompt_message_content",
+                return_value=real_image_content,
+            ) as mock_to_prompt,
+        ):
+            mem._build_prompt_message_with_files(
+                message_files=[MagicMock()],
+                text_content="user text",
+                message=_make_message(),
+                app_record=mock_app_record,
+                is_user_message=True,
+            )
+            # Ensure the LOW detail was passed through
+            mock_to_prompt.assert_called_once_with(
+                mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW
+            )
+
+    @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION])
+    def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode):
+        """app_record=None path → file_objs stays empty → plain messages."""
+        conv = _make_conversation(mode)
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        mock_file_extra_config = MagicMock()
+
+        with patch(
+            "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+            return_value=mock_file_extra_config,
+        ):
+            result = mem._build_prompt_message_with_files(
+                message_files=[MagicMock()],
+                text_content="hello",
+                message=_make_message(),
+                app_record=None,  # <-- forces the else branch → file_objs = []
+                is_user_message=True,
+            )
+
+        assert isinstance(result, UserPromptMessage)
+        assert result.content == "hello"
+
+    # ------------------------------------------------------------------
+    # Mode: ADVANCED_CHAT / WORKFLOW
+    # ------------------------------------------------------------------
+
+    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def test_workflow_mode_no_app_raises(self, mode):
+        """Raises ValueError when conversation.app is falsy."""
+        conv = _make_conversation(mode)
+        conv.app = None
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        with pytest.raises(ValueError, match="App not found for conversation"):
+            mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="text",
+                message=_make_message(),
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
+
+    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def test_workflow_mode_no_workflow_run_id_raises(self, mode):
+        """Raises ValueError when message.workflow_run_id is falsy."""
+        conv = _make_conversation(mode)
+        conv.app = MagicMock()
+
+        message = _make_message()
+        message.workflow_run_id = None  # force missing
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        with pytest.raises(ValueError, match="Workflow run ID not found"):
+            mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="text",
+                message=message,
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
+
+    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def test_workflow_mode_workflow_run_not_found_raises(self, mode):
+        """Raises ValueError when workflow_run_repo returns None."""
+        conv = _make_conversation(mode)
+        mock_app = MagicMock()
+        conv.app = mock_app
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+        mem._workflow_run_repo = MagicMock()
+        mem._workflow_run_repo.get_workflow_run_by_id.return_value = None
+
+        with pytest.raises(ValueError, match="Workflow run not found"):
+            mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="text",
+                message=_make_message(),
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
+
+    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def test_workflow_mode_workflow_not_found_raises(self, mode):
+        """Raises ValueError when Workflow lookup returns None."""
+        conv = _make_conversation(mode)
+        conv.app = MagicMock()
+
+        mock_workflow_run = MagicMock()
+        mock_workflow_run.workflow_id = str(uuid4())
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+        mem._workflow_run_repo = MagicMock()
+        mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+        ):
+            mock_db.session.scalar.return_value = None  # workflow not found
+
+            with pytest.raises(ValueError, match="Workflow not found"):
+                mem._build_prompt_message_with_files(
+                    message_files=[],
+                    text_content="text",
+                    message=_make_message(),
+                    app_record=MagicMock(),
+                    is_user_message=True,
+                )
+
+    @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def test_workflow_mode_success_no_files_user(self, mode):
+        """Happy path: workflow mode, no message files → plain UserPromptMessage."""
+        conv = _make_conversation(mode)
+        conv.app = MagicMock()
+
+        mock_workflow_run = MagicMock()
+        mock_workflow_run.workflow_id = str(uuid4())
+
+        mock_workflow = MagicMock()
+        mock_workflow.features_dict = {}
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+        mem._workflow_run_repo = MagicMock()
+        mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalar.return_value = mock_workflow
+
+            result = mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="wf text",
+                message=_make_message(),
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
+
+        assert isinstance(result, UserPromptMessage)
+        assert result.content == "wf text"
+
+    # ------------------------------------------------------------------
+    # Invalid mode
+    # ------------------------------------------------------------------
+
+    def test_invalid_mode_raises_assertion(self):
+        """Any unknown AppMode raises AssertionError."""
+        conv = _make_conversation()
+        conv.mode = "unknown_mode"  # not in any set
+        mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+        with pytest.raises(AssertionError, match="Invalid app mode"):
+            mem._build_prompt_message_with_files(
+                message_files=[],
+                text_content="text",
+                message=_make_message(),
+                app_record=MagicMock(),
+                is_user_message=True,
+            )
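For orientation, the CHAT/AGENT_CHAT/COMPLETION tests above all exercise one branch shape. Here is that branch as the tests read it — _build_chat_prompt is a hypothetical name and this is not the actual implementation in core/memory/token_buffer_memory.py:

def _build_chat_prompt(file_objs, text_content, is_user_message, image_detail=None):
    if not file_objs:  # no upload config, or app_record was None
        content = text_content
    else:
        # image_detail comes from file_extra_config.image_config.detail when set
        content = [file_manager.to_prompt_message_content(f, image_detail_config=image_detail) for f in file_objs]
        content.append(TextPromptMessageContent(data=text_content))  # text part goes last
    cls = UserPromptMessage if is_user_message else AssistantPromptMessage
    return cls(content=content)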
+
+
+# ===========================================================================
+# Tests for get_history_prompt_messages
+# ===========================================================================
+
+
+class TestGetHistoryPromptMessages:
+    """Tests for get_history_prompt_messages."""
+
+    def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory:
+        conv = _make_conversation(mode)
+        conv.app = MagicMock()
+        return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+    def test_returns_empty_when_no_messages(self):
+        mem = self._make_memory()
+        with patch("core.memory.token_buffer_memory.db") as mock_db:
+            mock_db.session.scalars.return_value.all.return_value = []
+            result = mem.get_history_prompt_messages()
+        assert result == []
+
+    def test_skips_first_message_without_answer(self):
+        """The newest message (index 0 after extraction) without answer and tokens==0 is skipped."""
+        mem = self._make_memory()
+
+        msg_no_answer = _make_message(answer="", answer_tokens=0)
+        msg_no_answer.parent_message_id = None  # ensures extract_thread_messages returns it
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg_no_answer],
+            ),
+        ):
+            mock_db.session.scalars.return_value.all.side_effect = [
+                [msg_no_answer],  # first call: messages query
+                [],  # second call: user files query (never hit, but safe)
+            ]
+            result = mem.get_history_prompt_messages()
+
+        assert result == []
+
+    def test_message_with_answer_not_skipped(self):
+        """A message with a non-empty answer is NOT popped."""
+        mem = self._make_memory()
+
+        msg = _make_message(answer="some answer", answer_tokens=10)
+        msg.parent_message_id = None
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            # user files query → empty; assistant files query → empty
+            mock_db.session.scalars.return_value.all.return_value = []
+            result = mem.get_history_prompt_messages()
+
+        assert len(result) == 2  # one user + one assistant
+
+    def test_message_limit_default_is_500(self):
+        """When message_limit is None the stmt is limited to 500."""
+        mem = self._make_memory()
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch("core.memory.token_buffer_memory.select") as mock_select,
+            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
+        ):
+            mock_stmt = MagicMock()
+            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
+            mock_stmt.limit.return_value = mock_stmt
+            mock_db.session.scalars.return_value.all.return_value = []
+
+            mem.get_history_prompt_messages(message_limit=None)
+            mock_stmt.limit.assert_called_with(500)
+
+    def test_message_limit_clipped_to_500(self):
+        """A message_limit > 500 is clamped to 500."""
+        mem = self._make_memory()
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch("core.memory.token_buffer_memory.select") as mock_select,
+            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
+        ):
+            mock_stmt = MagicMock()
+            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
+            mock_stmt.limit.return_value = mock_stmt
+            mock_db.session.scalars.return_value.all.return_value = []
+
+            mem.get_history_prompt_messages(message_limit=9999)
+            mock_stmt.limit.assert_called_with(500)
+
+    def test_message_limit_positive_used(self):
+        """A positive message_limit < 500 is used as-is."""
+        mem = self._make_memory()
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch("core.memory.token_buffer_memory.select") as mock_select,
+            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
+        ):
+            mock_stmt = MagicMock()
+            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
+            mock_stmt.limit.return_value = mock_stmt
+            mock_db.session.scalars.return_value.all.return_value = []
+
+            mem.get_history_prompt_messages(message_limit=10)
+            mock_stmt.limit.assert_called_with(10)
+
+    def test_message_limit_zero_uses_default(self):
+        """message_limit=0 triggers the else branch → default 500."""
+        mem = self._make_memory()
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch("core.memory.token_buffer_memory.select") as mock_select,
+            patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]),
+        ):
+            mock_stmt = MagicMock()
+            mock_select.return_value.where.return_value.order_by.return_value = mock_stmt
+            mock_stmt.limit.return_value = mock_stmt
+            mock_db.session.scalars.return_value.all.return_value = []
+
+            mem.get_history_prompt_messages(message_limit=0)
+            mock_stmt.limit.assert_called_with(500)
+
+    def test_user_files_cause_build_with_files_call(self):
+        """When user_files is non-empty _build_prompt_message_with_files is invoked."""
+        mem = self._make_memory()
+        msg = _make_message()
+        msg.parent_message_id = None
+
+        mock_user_file = MagicMock()
+        mock_user_prompt = UserPromptMessage(content="from build")
+        mock_assistant_prompt = AssistantPromptMessage(content="answer")
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                # messages query
+                r.all.return_value = [msg]
+            elif call_count["n"] == 1:
+                # user files
+                r.all.return_value = [mock_user_file]
+            else:
+                # assistant files
+                r.all.return_value = []
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch.object(
+                mem,
+                "_build_prompt_message_with_files",
+                side_effect=[mock_user_prompt, mock_assistant_prompt],
+            ) as mock_build,
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages()
+
+        assert mock_build.call_count >= 1
+        # First call should be user message
+        first_call_kwargs = mock_build.call_args_list[0][1]
+        assert first_call_kwargs["is_user_message"] is True
+
+    def test_assistant_files_cause_build_with_files_call(self):
+        """When assistant_files is non-empty, build is called with is_user_message=False."""
+        mem = self._make_memory()
+        msg = _make_message()
+        msg.parent_message_id = None
+
+        mock_assistant_file = MagicMock()
+        mock_user_prompt = UserPromptMessage(content="query")
+        mock_assistant_prompt = AssistantPromptMessage(content="built")
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                r.all.return_value = [msg]
+            elif call_count["n"] == 1:
+                r.all.return_value = []  # no user files
+            else:
+                r.all.return_value = [mock_assistant_file]
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch.object(
+                mem,
+                "_build_prompt_message_with_files",
+                return_value=mock_assistant_prompt,
+            ) as mock_build,
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages()
+
+        mock_build.assert_called_once()
+        call_kwargs = mock_build.call_args[1]
+        assert call_kwargs["is_user_message"] is False
+
+    def test_token_pruning_removes_oldest_messages(self):
+        """If tokens exceed limit, oldest messages are removed until within limit."""
+        conv = _make_conversation()
+        conv.app = MagicMock()
+
+        # Model returns tokens that decrease only after removing pairs
+        token_values = [3000, 1500]  # first call over limit, second within
+        mi = MagicMock()
+        mi.get_llm_num_tokens.side_effect = token_values
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+
+        msg = _make_message()
+        msg.parent_message_id = None
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                r.all.return_value = [msg]
+            else:
+                r.all.return_value = []
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages(max_token_limit=2000)
+
+        # After pruning, we should have fewer than the 2 initial messages
+        assert len(result) <= 1
+
+    def test_token_pruning_stops_at_single_message(self):
+        """Pruning stops when only 1 message remains (to prevent empty list)."""
+        conv = _make_conversation()
+        conv.app = MagicMock()
+
+        # Always over limit
+        mi = MagicMock()
+        mi.get_llm_num_tokens.return_value = 99999
+
+        mem = TokenBufferMemory(conversation=conv, model_instance=mi)
+
+        msg = _make_message()
+        msg.parent_message_id = None
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                r.all.return_value = [msg]
+            else:
+                r.all.return_value = []
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages(max_token_limit=1)
+
+        # At least 1 message should remain
+        assert len(result) >= 1
+
+    def test_no_pruning_when_within_limit(self):
+        """When tokens ≤ limit, no pruning occurs."""
+        mem = self._make_memory()
+        mem.model_instance.get_llm_num_tokens.return_value = 50  # well under default 2000
+
+        msg = _make_message()
+        msg.parent_message_id = None
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                r.all.return_value = [msg]
+            else:
+                r.all.return_value = []
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages(max_token_limit=2000)
+
+        assert len(result) == 2  # user + assistant
+
+    def test_plain_user_and_assistant_messages_returned(self):
+        """Without files, plain UserPromptMessage and AssistantPromptMessage appear."""
+        mem = self._make_memory()
+
+        msg = _make_message(answer="My answer")
+        msg.query = "My query"
+        msg.parent_message_id = None
+
+        call_count = {"n": 0}
+
+        def scalars_side_effect(stmt):
+            r = MagicMock()
+            if call_count["n"] == 0:
+                r.all.return_value = [msg]
+            else:
+                r.all.return_value = []
+            call_count["n"] += 1
+            return r
+
+        with (
+            patch("core.memory.token_buffer_memory.db") as mock_db,
+            patch(
+                "core.memory.token_buffer_memory.extract_thread_messages",
+                return_value=[msg],
+            ),
+            patch(
+                "core.memory.token_buffer_memory.FileUploadConfigManager.convert",
+                return_value=None,
+            ),
+        ):
+            mock_db.session.scalars.side_effect = scalars_side_effect
+            result = mem.get_history_prompt_messages()
+
+        assert len(result) == 2
+        user_msg, ai_msg = result
+        assert isinstance(user_msg, UserPromptMessage)
+        assert user_msg.content == "My query"
+        assert isinstance(ai_msg, AssistantPromptMessage)
+        assert ai_msg.content == "My answer"
+
+
+# ===========================================================================
+# Tests for get_history_prompt_text
+# ===========================================================================
+
+
+class TestGetHistoryPromptText:
+    """Tests for get_history_prompt_text."""
+
+    def _make_memory(self) -> TokenBufferMemory:
+        conv = _make_conversation()
+        conv.app = MagicMock()
+        return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance())
+
+    def test_empty_messages_returns_empty_string(self):
+        mem = self._make_memory()
+        with patch.object(mem, "get_history_prompt_messages", return_value=[]):
+            result = mem.get_history_prompt_text()
+        assert result == ""
+
+    def test_user_and_assistant_messages_formatted(self):
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(content="Hello"),
+            AssistantPromptMessage(content="World"),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text(human_prefix="H", ai_prefix="A")
+        assert result == "H: Hello\nA: World"
+
+    def test_custom_prefixes_applied(self):
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(content="Hi"),
+            AssistantPromptMessage(content="Bye"),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot")
+        assert "Human: Hi" in result
+        assert "Bot: Bye" in result
+
+    def test_list_content_with_text_and_image(self):
+        """List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image]."""
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(
+                content=[
+                    TextPromptMessageContent(data="caption"),
+                    ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"),
+                ]
+            ),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text()
+        assert "caption" in result
+        assert "[image]" in result
+
+    def test_list_content_text_only(self):
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(content=[TextPromptMessageContent(data="just text")]),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text()
+        assert "just text" in result
+
+    def test_list_content_image_only(self):
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(
+                content=[
+                    ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"),
+                ]
+            ),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text()
+        assert "[image]" in result
+
+    def test_unknown_role_skipped(self):
+        """Messages with a role that is not USER or ASSISTANT are skipped."""
+        mem = self._make_memory()
+
+        # Create a mock message with a SYSTEM role
+        system_msg = MagicMock()
+        system_msg.role = PromptMessageRole.SYSTEM
+        system_msg.content = "system instruction"
+
+        user_msg = UserPromptMessage(content="hi")
+
+        with patch.object(mem, "get_history_prompt_messages", return_value=[system_msg, user_msg]):
+            result = mem.get_history_prompt_text()
+
+        assert "system instruction" not in result
+        assert "Human: hi" in result
+
+    def test_passes_max_token_limit_and_message_limit(self):
+        """Parameters are forwarded to get_history_prompt_messages."""
+        mem = self._make_memory()
+        with patch.object(mem, "get_history_prompt_messages", return_value=[]) as mock_get:
+            mem.get_history_prompt_text(max_token_limit=500, message_limit=10)
+            mock_get.assert_called_once_with(max_token_limit=500, message_limit=10)
+
+    def test_multiple_messages_joined_by_newline(self):
+        mem = self._make_memory()
+        messages = [
+            UserPromptMessage(content="Q1"),
+            AssistantPromptMessage(content="A1"),
+            UserPromptMessage(content="Q2"),
+            AssistantPromptMessage(content="A2"),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text()
+        lines = result.split("\n")
+        assert len(lines) == 4
+        assert lines[0] == "Human: Q1"
+        assert lines[1] == "Assistant: A1"
+        assert lines[2] == "Human: Q2"
+        assert lines[3] == "Assistant: A2"
+
+    def test_assistant_list_content_formatted(self):
+        """AssistantPromptMessage with list content is also handled."""
+        mem = self._make_memory()
+        messages = [
+            AssistantPromptMessage(
+                content=[
+                    TextPromptMessageContent(data="response text"),
+                    ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"),
+                ]
+            ),
+        ]
+        with patch.object(mem, "get_history_prompt_messages", return_value=messages):
+            result = mem.get_history_prompt_text()
+        assert "response text" in result
+        assert "[image]" in result
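The three pruning tests above imply a simple drop-oldest loop that never empties the history entirely. A sketch of that loop under the assumptions the tests make (get_llm_num_tokens is recounted after each drop; _prune_to_token_limit is a hypothetical name, not the real TokenBufferMemory code):

def _prune_to_token_limit(self, prompt_messages, max_token_limit):
    curr_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
    while curr_tokens > max_token_limit and len(prompt_messages) > 1:
        prompt_messages.pop(0)  # drop the oldest message first
        curr_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
    return prompt_messages  # at least one message always survives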
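Likewise, the get_history_prompt_text tests fix a formatting contract: roles map to prefixes, list content flattens text parts and renders images as "[image]", and any other role is skipped. A self-contained sketch under those assumptions (format_history is a hypothetical name):

def format_history(messages, human_prefix="Human", ai_prefix="Assistant"):
    lines = []
    for m in messages:
        if m.role == PromptMessageRole.USER:
            prefix = human_prefix
        elif m.role == PromptMessageRole.ASSISTANT:
            prefix = ai_prefix
        else:
            continue  # e.g. SYSTEM messages are skipped
        if isinstance(m.content, list):
            text = "".join(c.data if isinstance(c, TextPromptMessageContent) else "[image]" for c in m.content)
        else:
            text = m.content
        lines.append(f"{prefix}: {text}")
    return "\n".join(lines)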
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
new file mode 100644
index 0000000000..acb43d4036
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
@@ -0,0 +1,326 @@
+import time
+import uuid
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import httpx
+import pytest
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.trace import SpanKind, Status, StatusCode
+
+from core.ops.aliyun_trace.data_exporter.traceclient import (
+    INVALID_SPAN_ID,
+    SpanBuilder,
+    TraceClient,
+    build_endpoint,
+    convert_datetime_to_nanoseconds,
+    convert_string_to_id,
+    convert_to_span_id,
+    convert_to_trace_id,
+    create_link,
+    generate_span_id,
+)
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+
+
+@pytest.fixture
+def trace_client_factory():
+    """Factory fixture for creating TraceClient instances with automatic cleanup."""
+    clients_to_shutdown = []
+
+    def _factory(**kwargs):
+        client = TraceClient(**kwargs)
+        clients_to_shutdown.append(client)
+        return client
+
+    yield _factory
+
+    # Cleanup: shutdown all created clients
+    for client in clients_to_shutdown:
+        client.shutdown()
+
+
+class TestTraceClient:
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
+    def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
+        mock_gethostname.return_value = "test-host"
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+
+        assert client.endpoint == "http://test-endpoint"
+        assert client.max_queue_size == 1000
+        assert client.schedule_delay_sec == 5
+        assert client.done is False
+        assert client.worker_thread.is_alive()
+
+        client.shutdown()
+        assert client.done is True
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_export(self, mock_exporter_class, trace_client_factory):
+        mock_exporter = mock_exporter_class.return_value
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        spans = [MagicMock(spec=ReadableSpan)]
+        client.export(spans)
+        mock_exporter.export.assert_called_once_with(spans)
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
+        mock_response = MagicMock()
+        mock_response.status_code = 405
+        mock_head.return_value = mock_response
+
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        assert client.api_check() is True
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_head.return_value = mock_response
+
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        assert client.api_check() is False
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
+        mock_head.side_effect = httpx.RequestError("Connection error")
+
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
+            client.api_check()
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_get_project_url(self, mock_exporter_class, trace_client_factory):
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_add_span(self, mock_exporter_class, trace_client_factory):
+        client = trace_client_factory(
+            service_name="test-service",
+            endpoint="http://test-endpoint",
+            max_export_batch_size=2,
+        )
+
+        # Test add None
+        client.add_span(None)
+        assert len(client.queue) == 0
+
+        # Test add valid SpanData
+        span_data = SpanData(
+            name="test-span",
+            trace_id=123,
+            span_id=456,
+            parent_span_id=None,
+            start_time=1000,
+            end_time=2000,
+            status=Status(StatusCode.OK),
+            span_kind=SpanKind.INTERNAL,
+        )
+
+        mock_span = MagicMock(spec=ReadableSpan)
+        client.span_builder.build_span = MagicMock(return_value=mock_span)
+
+        with patch.object(client.condition, "notify") as mock_notify:
+            client.add_span(span_data)
+            assert len(client.queue) == 1
+            mock_notify.assert_not_called()
+
+            client.add_span(span_data)
+            assert len(client.queue) == 2
+            mock_notify.assert_called_once()
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
+    def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
+
+        span_data = SpanData(
+            name="test-span",
+            trace_id=123,
+            span_id=456,
+            parent_span_id=None,
+            start_time=1000,
+            end_time=2000,
+            status=Status(StatusCode.OK),
+            span_kind=SpanKind.INTERNAL,
+        )
+        mock_span = MagicMock(spec=ReadableSpan)
+        client.span_builder.build_span = MagicMock(return_value=mock_span)
+
+        client.add_span(span_data)
+        assert len(client.queue) == 1
+
+        client.add_span(span_data)
+        assert len(client.queue) == 1
+        mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
+        mock_exporter = mock_exporter_class.return_value
+        mock_exporter.export.side_effect = Exception("Export failed")
+
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+        mock_span = MagicMock(spec=ReadableSpan)
+        client.queue.append(mock_span)
+
+        with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
+            client._export_batch()
+            mock_logger.warning.assert_called()
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_worker_loop(self, mock_exporter_class, trace_client_factory):
+        # We need to test the wait timeout in _worker
+        # But _worker runs in a thread. Let's mock condition.wait.
+        client = trace_client_factory(
+            service_name="test-service",
+            endpoint="http://test-endpoint",
+            schedule_delay_sec=0.1,
+        )
+
+        with patch.object(client.condition, "wait") as mock_wait:
+            # Let it run for a bit then shut down
+            time.sleep(0.2)
+            client.shutdown()
+            # mock_wait might have been called
+            assert mock_wait.called or client.done
+
+    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+    def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
+        mock_exporter = mock_exporter_class.return_value
+        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
+
+        mock_span = MagicMock(spec=ReadableSpan)
+        client.queue.append(mock_span)
+
+        client.shutdown()
+        # Should have called export twice (once in worker/export_batch, once in shutdown)
+        # or at least once if worker was waiting
+        assert mock_exporter.export.called
+        assert mock_exporter.shutdown.called
+
+
+class TestSpanBuilder:
+    def test_build_span(self):
+        resource = MagicMock()
+        builder = SpanBuilder(resource)
+
+        span_data = SpanData(
+            name="test-span",
+            trace_id=123,
+            span_id=456,
+            parent_span_id=789,
+            start_time=1000,
+            end_time=2000,
+            status=Status(StatusCode.OK),
+            span_kind=SpanKind.INTERNAL,
+            attributes={"attr1": "val1"},
+            events=[],
+            links=[],
+        )
+
+        span = builder.build_span(span_data)
+        assert isinstance(span, ReadableSpan)
+        assert span.name == "test-span"
+        assert span.context.trace_id == 123
+        assert span.context.span_id == 456
+        assert span.parent.span_id == 789
+        assert span.resource == resource
+        assert span.attributes == {"attr1": "val1"}
+
+    def test_build_span_no_parent(self):
+        resource = MagicMock()
+        builder = SpanBuilder(resource)
+
+        span_data = SpanData(
+            name="test-span",
+            trace_id=123,
+            span_id=456,
+            parent_span_id=None,
+            start_time=1000,
+            end_time=2000,
+            status=Status(StatusCode.OK),
+            span_kind=SpanKind.INTERNAL,
+        )
+
+        span = builder.build_span(span_data)
+        assert span.parent is None
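The TestTraceClient queue tests above describe the producer side of a bounded batch exporter: None is ignored, a full queue drops the span with a warning, and reaching the batch size wakes the worker. A sketch of that behavior — the field names follow the tests, not necessarily the real traceclient module:

def add_span(self, span_data):
    if span_data is None:
        return  # ignored; queue length unchanged
    span = self.span_builder.build_span(span_data)
    with self.condition:
        if len(self.queue) >= self.max_queue_size:
            logger.warning("Queue is full, likely spans will be dropped.")
            return
        self.queue.append(span)
        if len(self.queue) >= self.max_export_batch_size:
            self.condition.notify()  # wake the worker thread to flush a batch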
+
+
+def test_create_link():
+    trace_id_str = "0123456789abcdef0123456789abcdef"
+    link = create_link(trace_id_str)
+    assert link.context.trace_id == int(trace_id_str, 16)
+    assert link.context.span_id == INVALID_SPAN_ID
+
+    with pytest.raises(ValueError, match="Invalid trace ID format"):
+        create_link("invalid-hex")
+
+
+def test_generate_span_id():
+    # Test normal generation
+    span_id = generate_span_id()
+    assert isinstance(span_id, int)
+    assert span_id != INVALID_SPAN_ID
+
+    # Test retry loop
+    with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
+        mock_rand.side_effect = [INVALID_SPAN_ID, 999]
+        span_id = generate_span_id()
+        assert span_id == 999
+        assert mock_rand.call_count == 2
+
+
+def test_convert_to_trace_id():
+    uid = str(uuid.uuid4())
+    trace_id = convert_to_trace_id(uid)
+    assert trace_id == uuid.UUID(uid).int
+
+    with pytest.raises(ValueError, match="UUID cannot be None"):
+        convert_to_trace_id(None)
+
+    with pytest.raises(ValueError, match="Invalid UUID input"):
+        convert_to_trace_id("not-a-uuid")
+
+
+def test_convert_string_to_id():
+    assert convert_string_to_id("test") > 0
+    # Test with None string
+    with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
+        mock_gen.return_value = 12345
+        assert convert_string_to_id(None) == 12345
+
+
+def test_convert_to_span_id():
+    uid = str(uuid.uuid4())
+    span_id = convert_to_span_id(uid, "test-type")
+    assert isinstance(span_id, int)
+
+    with pytest.raises(ValueError, match="UUID cannot be None"):
+        convert_to_span_id(None, "test")
+
+    with pytest.raises(ValueError, match="Invalid UUID input"):
+        convert_to_span_id("not-a-uuid", "test")
+
+
+def test_convert_datetime_to_nanoseconds():
+    dt = datetime(2023, 1, 1, 12, 0, 0)
+    ns = convert_datetime_to_nanoseconds(dt)
+    assert ns == int(dt.timestamp() * 1e9)
+    assert convert_datetime_to_nanoseconds(None) is None
+
+
+def test_build_endpoint():
+    license_key = "abc"
+
+    # CMS 2.0 endpoint
+    url1 = "https://log.aliyuncs.com"
+    assert build_endpoint(url1, license_key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces"
+
+    # XTrace endpoint
+    url2 = "https://example.com"
+    assert build_endpoint(url2, license_key) == "https://example.com/adapt_abc/api/otlp/traces"
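The two assertions in test_build_endpoint imply a two-way dispatch on the ingestion host. A guess at that dispatch, hedged — the actual predicate in traceclient.py may differ:

def build_endpoint(base_url, license_key):
    if "log.aliyuncs.com" in base_url:  # CMS 2.0 ingestion endpoint
        return f"{base_url}/adapt_{license_key}/api/v1/traces"
    return f"{base_url}/adapt_{license_key}/api/otlp/traces"  # XTrace endpoint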
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
new file mode 100644
index 0000000000..2fcb927e0c
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
@@ -0,0 +1,88 @@
+import pytest
+from opentelemetry import trace as trace_api
+from opentelemetry.sdk.trace import Event
+from opentelemetry.trace import SpanKind, Status, StatusCode
+from pydantic import ValidationError
+
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
+
+
+class TestTraceMetadata:
+    def test_trace_metadata_init(self):
+        links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))]
+        metadata = TraceMetadata(
+            trace_id=123, workflow_span_id=456, session_id="session_1", user_id="user_1", links=links
+        )
+        assert metadata.trace_id == 123
+        assert metadata.workflow_span_id == 456
+        assert metadata.session_id == "session_1"
+        assert metadata.user_id == "user_1"
+        assert metadata.links == links
+
+
+class TestSpanData:
+    def test_span_data_init_required_fields(self):
+        span_data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000)
+        assert span_data.trace_id == 123
+        assert span_data.span_id == 456
+        assert span_data.name == "test_span"
+        assert span_data.start_time == 1000
+        assert span_data.end_time == 2000
+
+        # Check defaults
+        assert span_data.parent_span_id is None
+        assert span_data.attributes == {}
+        assert span_data.events == []
+        assert span_data.links == []
+        assert span_data.status.status_code == StatusCode.UNSET
+        assert span_data.span_kind == SpanKind.INTERNAL
+
+    def test_span_data_with_optional_fields(self):
+        event = Event(name="event_1", timestamp=1500)
+        link = trace_api.Link(context=trace_api.SpanContext(0, 0, False))
+        status = Status(StatusCode.OK)
+
+        span_data = SpanData(
+            trace_id=123,
+            parent_span_id=111,
+            span_id=456,
+            name="test_span",
+            attributes={"key": "value"},
+            events=[event],
+            links=[link],
+            status=status,
+            start_time=1000,
+            end_time=2000,
+            span_kind=SpanKind.SERVER,
+        )
+
+        assert span_data.parent_span_id == 111
+        assert span_data.attributes == {"key": "value"}
+        assert span_data.events == [event]
+        assert span_data.links == [link]
+        assert span_data.status.status_code == status.status_code
+        assert span_data.span_kind == SpanKind.SERVER
+
+    def test_span_data_missing_required_fields(self):
+        with pytest.raises(ValidationError):
+            SpanData(
+                trace_id=123,
+                # span_id missing
+                name="test_span",
+                start_time=1000,
+                end_time=2000,
+            )
+
+    def test_span_data_arbitrary_types_allowed(self):
+        # opentelemetry.trace.Status and Event are "arbitrary types" for Pydantic
+        # This test ensures they are accepted thanks to model_config
+        status = Status(StatusCode.ERROR, description="error occurred")
+        event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"})
+
+        span_data = SpanData(
+            trace_id=123, span_id=456, name="test_span", status=status, events=[event], start_time=1000, end_time=2000
+        )
+
+        assert span_data.status.status_code == status.status_code
+        assert span_data.status.description == status.description
+        assert span_data.events == [event]
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
new file mode 100644
index 0000000000..3961555b9a
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
@@ -0,0 +1,68 @@
+from core.ops.aliyun_trace.entities.semconv import (
+    ACS_ARMS_SERVICE_FEATURE,
+    GEN_AI_COMPLETION,
+    GEN_AI_FRAMEWORK,
+    GEN_AI_INPUT_MESSAGE,
+    GEN_AI_OUTPUT_MESSAGE,
+    GEN_AI_PROMPT,
+    GEN_AI_PROVIDER_NAME,
+    GEN_AI_REQUEST_MODEL,
+    GEN_AI_RESPONSE_FINISH_REASON,
+    GEN_AI_SESSION_ID,
+    GEN_AI_SPAN_KIND,
+    GEN_AI_USAGE_INPUT_TOKENS,
+    GEN_AI_USAGE_OUTPUT_TOKENS,
+    GEN_AI_USAGE_TOTAL_TOKENS,
+    GEN_AI_USER_ID,
+    GEN_AI_USER_NAME,
+    INPUT_VALUE,
+    OUTPUT_VALUE,
+    RETRIEVAL_DOCUMENT,
+    RETRIEVAL_QUERY,
+    TOOL_DESCRIPTION,
+    TOOL_NAME,
+    TOOL_PARAMETERS,
+    GenAISpanKind,
+)
+
+
+def test_constants():
+    assert ACS_ARMS_SERVICE_FEATURE == "acs.arms.service.feature"
+    assert GEN_AI_SESSION_ID == "gen_ai.session.id"
+    assert GEN_AI_USER_ID == "gen_ai.user.id"
+    assert GEN_AI_USER_NAME == "gen_ai.user.name"
+    assert GEN_AI_SPAN_KIND == "gen_ai.span.kind"
+    assert GEN_AI_FRAMEWORK == "gen_ai.framework"
+    assert INPUT_VALUE == "input.value"
+    assert OUTPUT_VALUE == "output.value"
+    assert RETRIEVAL_QUERY == "retrieval.query"
+    assert RETRIEVAL_DOCUMENT == "retrieval.document"
+    assert GEN_AI_REQUEST_MODEL == "gen_ai.request.model"
+    assert GEN_AI_PROVIDER_NAME == "gen_ai.provider.name"
+    assert GEN_AI_USAGE_INPUT_TOKENS == "gen_ai.usage.input_tokens"
+    assert GEN_AI_USAGE_OUTPUT_TOKENS == "gen_ai.usage.output_tokens"
+    assert GEN_AI_USAGE_TOTAL_TOKENS == "gen_ai.usage.total_tokens"
+    assert GEN_AI_PROMPT == "gen_ai.prompt"
+    assert GEN_AI_COMPLETION == "gen_ai.completion"
+    assert GEN_AI_RESPONSE_FINISH_REASON == "gen_ai.response.finish_reason"
+    assert GEN_AI_INPUT_MESSAGE == "gen_ai.input.messages"
+    assert GEN_AI_OUTPUT_MESSAGE == "gen_ai.output.messages"
+    assert TOOL_NAME == "tool.name"
+    assert TOOL_DESCRIPTION == "tool.description"
+    assert TOOL_PARAMETERS == "tool.parameters"
+
+
+def test_gen_ai_span_kind_enum():
+    assert GenAISpanKind.CHAIN == "CHAIN"
+    assert GenAISpanKind.RETRIEVER == "RETRIEVER"
+    assert GenAISpanKind.RERANKER == "RERANKER"
+    assert GenAISpanKind.LLM == "LLM"
+    assert GenAISpanKind.EMBEDDING == "EMBEDDING"
+    assert GenAISpanKind.TOOL == "TOOL"
+    assert GenAISpanKind.AGENT == "AGENT"
+    assert GenAISpanKind.TASK == "TASK"
+
+    # Verify iteration works (covers the class definition)
+    kinds = list(GenAISpanKind)
+    assert len(kinds) == 8
+    assert "LLM" in kinds
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
new file mode 100644
index 0000000000..fac0597f5a
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
@@ -0,0 +1,647 @@
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
+
+import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
+from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
+from core.ops.aliyun_trace.entities.semconv import (
+    GEN_AI_COMPLETION,
+    GEN_AI_INPUT_MESSAGE,
+    GEN_AI_OUTPUT_MESSAGE,
+    GEN_AI_PROMPT,
+    GEN_AI_REQUEST_MODEL,
+    GEN_AI_RESPONSE_FINISH_REASON,
+    GEN_AI_USAGE_TOTAL_TOKENS,
+    RETRIEVAL_DOCUMENT,
+    RETRIEVAL_QUERY,
+    TOOL_DESCRIPTION,
+    TOOL_NAME,
+    TOOL_PARAMETERS,
+    GenAISpanKind,
+)
+from core.ops.entities.config_entity import AliyunConfig
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
+from dify_graph.entities import WorkflowNodeExecution
+from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
+
+
+class RecordingTraceClient:
+    def __init__(self, service_name: str = "service", endpoint: str = "endpoint"):
+        self.service_name = service_name
+        self.endpoint = endpoint
+        self.added_spans: list[object] = []
+
+    def add_span(self, span) -> None:
+        self.added_spans.append(span)
+
+    def api_check(self) -> bool:
+        return True
+
+    def get_project_url(self) -> str:
+        return "project-url"
+
+
+def _dt() -> datetime:
+    return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
+
+
+def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
+    context = SpanContext(
+        trace_id=trace_id,
+        span_id=span_id,
+        is_remote=False,
+        trace_flags=TraceFlags.SAMPLED,
+    )
+    return Link(context)
+
+
+def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
+    defaults = {
+        "workflow_id": "workflow-id",
+        "tenant_id": "tenant-id",
+        "workflow_run_id": "00000000-0000-0000-0000-000000000001",
+        "workflow_run_elapsed_time": 1.0,
+        "workflow_run_status": "succeeded",
+        "workflow_run_inputs": {"sys.query": "hello"},
+        "workflow_run_outputs": {"answer": "world"},
+        "workflow_run_version": "v1",
+        "total_tokens": 1,
+        "file_list": [],
+        "query": "hello",
+        "metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"},
+        "message_id": None,
+        "start_time": _dt(),
+        "end_time": _dt(),
+        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
+    }
+    defaults.update(overrides)
+    return WorkflowTraceInfo(**defaults)
+
+
+def _make_message_trace_info(**overrides) -> MessageTraceInfo:
+    defaults = {
+        "conversation_model": "chat",
+        "message_tokens": 1,
+        "answer_tokens": 2,
+        "total_tokens": 3,
+        "conversation_mode": "chat",
+        "metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"},
+        "message_id": "00000000-0000-0000-0000-000000000002",
+        "message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None),
+        "inputs": {"prompt": "hi"},
+        "outputs": "ok",
+        "start_time": _dt(),
+        "end_time": _dt(),
+        "error": None,
+        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
+    }
+    defaults.update(overrides)
+    return MessageTraceInfo(**defaults)
+
+
+def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo:
+    defaults = {
+        "metadata": {"conversation_id": "conv", "user_id": "u"},
+        "message_id": "00000000-0000-0000-0000-000000000003",
+        "message_data": SimpleNamespace(),
+        "inputs": "q",
+        "documents": [SimpleNamespace()],
+        "start_time": _dt(),
+        "end_time": _dt(),
+        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
+    }
+    defaults.update(overrides)
+    return DatasetRetrievalTraceInfo(**defaults)
+
+
+def _make_tool_trace_info(**overrides) -> ToolTraceInfo:
+    defaults = {
+        "tool_name": "tool",
+        "tool_inputs": {"x": 1},
+        "tool_outputs": "out",
+        "tool_config": {"desc": "d"},
+        "tool_parameters": {},
+        "time_cost": 0.1,
+        "metadata": {"conversation_id": "conv", "user_id": "u"},
+        "message_id": "00000000-0000-0000-0000-000000000004",
+        "message_data": SimpleNamespace(),
+        "inputs": {"i": "v"},
+        "outputs": {"o": "v"},
+        "start_time": _dt(),
+        "end_time": _dt(),
+        "error": None,
+        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
+    }
+    defaults.update(overrides)
+    return ToolTraceInfo(**defaults)
+
+
+def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo:
+    defaults = {
+        "suggested_question": ["q1", "q2"],
+        "level": "info",
+        "total_tokens": 1,
+        "metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"},
+        "message_id": "00000000-0000-0000-0000-000000000005",
+        "inputs": {"i": 1},
+        "start_time": _dt(),
+        "end_time": _dt(),
+        "error": None,
+        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
+    }
+    defaults.update(overrides)
+    return SuggestedQuestionTraceInfo(**defaults)
+
+
+@pytest.fixture
+def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace:
+    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint")
+    monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient)
+    # Mock get_service_account_with_tenant to avoid DB errors
+    monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock())
+
+    config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com")
+    trace = AliyunDataTrace(config)
+    return trace
+
+
+def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch):
+    build_endpoint = MagicMock(return_value="built")
+    trace_client_cls = MagicMock()
+    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", build_endpoint)
+    monkeypatch.setattr(aliyun_trace_module, "TraceClient", trace_client_cls)
+
+    config = AliyunConfig(app_name="my-app",
license_key="license", endpoint="https://example.com") + trace = AliyunDataTrace(config) + + build_endpoint.assert_called_once_with("https://example.com", "license") + trace_client_cls.assert_called_once_with(service_name="my-app", endpoint="built") + assert trace.trace_config == config + + +def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + workflow_trace = MagicMock() + message_trace = MagicMock() + suggested_question_trace = MagicMock() + dataset_retrieval_trace = MagicMock() + tool_trace = MagicMock() + monkeypatch.setattr(trace_instance, "workflow_trace", workflow_trace) + monkeypatch.setattr(trace_instance, "message_trace", message_trace) + monkeypatch.setattr(trace_instance, "suggested_question_trace", suggested_question_trace) + monkeypatch.setattr(trace_instance, "dataset_retrieval_trace", dataset_retrieval_trace) + monkeypatch.setattr(trace_instance, "tool_trace", tool_trace) + + trace_instance.trace(_make_workflow_trace_info()) + workflow_trace.assert_called_once() + + trace_instance.trace(_make_message_trace_info()) + message_trace.assert_called_once() + + trace_instance.trace(_make_suggested_question_trace_info()) + suggested_question_trace.assert_called_once() + + trace_instance.trace(_make_dataset_retrieval_trace_info()) + dataset_retrieval_trace.assert_called_once() + + trace_instance.trace(_make_tool_trace_info()) + tool_trace.assert_called_once() + + # Branches that do nothing but should be covered + trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={})) + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + + +def test_api_check_delegates(trace_instance: AliyunDataTrace): + trace_instance.trace_client.api_check = MagicMock(return_value=False) + assert trace_instance.api_check() is False + + +def test_get_project_url_success(trace_instance: AliyunDataTrace): + assert trace_instance.get_project_url() == "project-url" + + +def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(trace_instance.trace_client, "get_project_url", MagicMock(side_effect=Exception("boom"))) + logger_mock = MagicMock() + monkeypatch.setattr(aliyun_trace_module, "logger", logger_mock) + + with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"): + trace_instance.get_project_url() + logger_mock.info.assert_called() + + +def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + add_workflow_span = MagicMock() + get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()]) + build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"]) + monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span) + monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions) + monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span) + + trace_info = _make_workflow_trace_info( + trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"} + ) + trace_instance.workflow_trace(trace_info) + + 
add_workflow_span.assert_called_once() + passed_trace_metadata = add_workflow_span.call_args.args[1] + assert passed_trace_metadata.trace_id == 111 + assert passed_trace_metadata.workflow_span_id == 222 + assert passed_trace_metadata.session_id == "c" + assert passed_trace_metadata.user_id == "u" + assert passed_trace_metadata.links == [] + + assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + + +def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user") + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_message_trace_info( + metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"}, + message_tokens=7, + answer_tokens=11, + total_tokens=18, + outputs="completion", + ) + trace_instance.message_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span, llm_span = trace_instance.trace_client.added_spans + + assert message_span.name == "message" + assert message_span.trace_id == 10 + assert message_span.parent_span_id is None + assert message_span.span_id == 20 + assert message_span.span_kind == SpanKind.SERVER + assert message_span.status == status + assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN + + assert llm_span.name == "llm" + assert llm_span.parent_span_id == 20 + assert llm_span.span_id == 30 + assert llm_span.status == status + assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model" + assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18" + + +def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) + assert len(trace_instance.trace_client.added_spans) == 1 + span = 
trace_instance.trace_client.added_spans[0] + assert span.name == "dataset_retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "query" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' + + +def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_tool_trace_info(message_data=None) + trace_instance.tool_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_instance.tool_trace( + _make_tool_trace_info( + tool_name="my-tool", + tool_inputs={"a": 1}, + tool_config={"description": "x"}, + inputs={"i": 1}, + ) + ) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "my-tool" + assert span.status == status + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}' + + +def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace): + trace_info = _make_workflow_trace_info(metadata={"conversation_id": "c"}) + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.get_workflow_node_executions(trace_info) + + +def test_get_workflow_node_executions_builds_repo_and_fetches( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"}) + + account = object() + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account)) + monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock()) + monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) + + repo = MagicMock() + repo.get_by_workflow_run.return_value = ["node1"] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) + + result = trace_instance.get_workflow_node_executions(trace_info) + assert result == ["node1"] + repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + + +def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) + + node_execution.node_type = NodeType.LLM + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm" + + +def test_build_workflow_node_span_routes_knowledge_retrieval_type( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + node_execution = 
MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) + + node_execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval" + + +def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) + + node_execution.node_type = NodeType.TOOL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool" + + +def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) + + node_execution.node_type = NodeType.CODE + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task" + + +def test_build_workflow_node_span_handles_errors( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) + node_execution.node_type = NodeType.CODE + + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None + assert "Error occurred in build_workflow_node_span" in caplog.text + + +def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "title" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.trace_id == 1 + assert span.span_id == 9 + assert span.status.status_code == StatusCode.OK + assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK + + +def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, 
session_id="s", user_id="u", links=[_make_link()]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "my-tool" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}} + + span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}' + assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}' + assert span.status.status_code == StatusCode.OK + + # Cover metadata is None and inputs is None + node_execution.metadata = None + node_execution.inputs = None + span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[TOOL_DESCRIPTION] == "{}" + assert span2.attributes[TOOL_PARAMETERS] == "{}" + + +def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr( + aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] + ) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "retrieval" + node_execution.inputs = {"query": "q"} + node_execution.outputs = {"result": [{"doc": "d"}]} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[RETRIEVAL_QUERY] == "q" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]' + + # Cover empty inputs/outputs + node_execution.inputs = None + node_execution.outputs = None + span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[RETRIEVAL_QUERY] == "" + assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]" + + +def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") + monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "llm" + node_execution.process_data = { + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + "prompts": ["p"], + "model_name": "m", + "model_provider": "p1", + } + 
node_execution.outputs = {"text": "t", "finish_reason": "stop"} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3" + assert span.attributes[GEN_AI_REQUEST_MODEL] == "m" + assert span.attributes[GEN_AI_PROMPT] == '["p"]' + assert span.attributes[GEN_AI_COMPLETION] == "t" + assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop" + assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in" + assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out" + + # Cover usage from outputs if not in process_data + node_execution.process_data = {"prompts": []} + node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""} + span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10" + + +def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + + # CASE 1: With message_id + trace_info = _make_workflow_trace_info( + message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"} + ) + trace_instance.add_workflow_span(trace_info, trace_metadata) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span = trace_instance.trace_client.added_spans[0] + workflow_span = trace_instance.trace_client.added_spans[1] + + assert message_span.name == "message" + assert message_span.span_kind == SpanKind.SERVER + assert message_span.parent_span_id is None + + assert workflow_span.name == "workflow" + assert workflow_span.span_kind == SpanKind.INTERNAL + assert workflow_span.parent_span_id == 20 + + trace_instance.trace_client.added_spans.clear() + + # CASE 2: Without message_id + trace_info_no_msg = _make_workflow_trace_info(message_id=None) + trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "workflow" + assert span.span_kind == SpanKind.SERVER + assert span.parent_span_id is None + + +def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) + trace_instance.suggested_question_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = 
trace_instance.trace_client.added_spans[0] + assert span.name == "suggested_question" + assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py new file mode 100644 index 0000000000..763fc90710 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -0,0 +1,275 @@ +import json +from unittest.mock import MagicMock + +from opentelemetry.trace import Link, StatusCode + +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_FRAMEWORK, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USER_ID, + INPUT_VALUE, + OUTPUT_VALUE, +) +from core.ops.aliyun_trace.utils import ( + create_common_span_attributes, + create_links_from_trace_id, + create_status_from_error, + extract_retrieval_documents, + format_input_messages, + format_output_messages, + format_retrieval_documents, + get_user_id_from_message_data, + get_workflow_node_status, + serialize_json_data, +) +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionStatus +from models import EndUser + + +def test_get_user_id_from_message_data_no_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = None + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_get_user_id_from_message_data_with_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + end_user_data = MagicMock(spec=EndUser) + end_user_data.session_id = "session_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = end_user_data + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "session_id" + + +def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_create_status_from_error(): + # Case OK + status_ok = create_status_from_error(None) + assert status_ok.status_code == StatusCode.OK + + # Case Error + status_err = create_status_from_error("some error") + assert status_err.status_code == StatusCode.ERROR + assert status_err.description == "some error" + + +def test_get_workflow_node_status(): + node_execution = MagicMock(spec=WorkflowNodeExecution) + + # SUCCEEDED + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.OK + + # FAILED + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = "node fail" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node fail" + + # EXCEPTION + node_execution.status 
= WorkflowNodeExecutionStatus.EXCEPTION + node_execution.error = "node exception" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node exception" + + # UNSET/OTHER + node_execution.status = WorkflowNodeExecutionStatus.RUNNING + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.UNSET + + +def test_create_links_from_trace_id(monkeypatch): + # Mock create_link + mock_link = MagicMock(spec=Link) + import core.ops.aliyun_trace.data_exporter.traceclient + + monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + + # Trace ID None + assert create_links_from_trace_id(None) == [] + + # Trace ID Present + links = create_links_from_trace_id("trace_id") + assert len(links) == 1 + assert links[0] == mock_link + + +def test_extract_retrieval_documents(): + doc1 = MagicMock(spec=Document) + doc1.page_content = "content1" + doc1.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9} + + doc2 = MagicMock(spec=Document) + doc2.page_content = "content2" + doc2.metadata = {"dataset_id": "ds2"} # Missing some keys + + documents = [doc1, doc2] + extracted = extract_retrieval_documents(documents) + + assert len(extracted) == 2 + assert extracted[0]["content"] == "content1" + assert extracted[0]["metadata"]["dataset_id"] == "ds1" + assert extracted[0]["score"] == 0.9 + + assert extracted[1]["content"] == "content2" + assert extracted[1]["metadata"]["dataset_id"] == "ds2" + assert extracted[1]["metadata"]["doc_id"] is None + assert extracted[1]["score"] is None + + +def test_serialize_json_data(): + data = {"a": 1} + # Test ensure_ascii default (False) + assert serialize_json_data(data) == json.dumps(data, ensure_ascii=False) + # Test ensure_ascii True + assert serialize_json_data(data, ensure_ascii=True) == json.dumps(data, ensure_ascii=True) + + +def test_create_common_span_attributes(): + attrs = create_common_span_attributes( + session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1" + ) + assert attrs[GEN_AI_SESSION_ID] == "s1" + assert attrs[GEN_AI_USER_ID] == "u1" + assert attrs[GEN_AI_SPAN_KIND] == "kind1" + assert attrs[GEN_AI_FRAMEWORK] == "fw1" + assert attrs[INPUT_VALUE] == "in1" + assert attrs[OUTPUT_VALUE] == "out1" + + +def test_format_retrieval_documents(): + # Not a list + assert format_retrieval_documents("not a list") == [] + + # Valid list + docs = [ + {"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"}, + { + "metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}}, + "content": "c2", + # Missing title + }, + "not a dict", # Should be skipped + ] + formatted = format_retrieval_documents(docs) + + assert len(formatted) == 2 + assert formatted[0]["document"]["content"] == "c1" + assert formatted[0]["document"]["metadata"]["title"] == "t1" + assert formatted[0]["document"]["metadata"]["source"] == "src1" + assert formatted[0]["document"]["score"] == 0.8 + assert formatted[0]["document"]["id"] == "doc1" + + assert formatted[1]["document"]["content"] == "c2" + assert formatted[1]["document"]["metadata"]["source"] == "src2" + assert formatted[1]["document"]["metadata"]["extra"] == "val" + assert "title" not in formatted[1]["document"]["metadata"] + assert formatted[1]["document"]["score"] == 0.0 # Default + + # Exception handling + # We can trigger an exception by passing something that 
raises inside the loop; the try/except wraps the whole function,
+    # so any failure silently yields []. Plain dicts never raise on .get(),
+    # so we use a stub whose .get() always throws.
+    class BadDict:
+        def get(self, *args, **kwargs):
+            raise Exception("boom")
+
+    assert format_retrieval_documents([BadDict()]) == []
+
+
+def test_format_input_messages():
+    # Not a dict
+    assert format_input_messages(None) == serialize_json_data([])
+
+    # No prompts
+    assert format_input_messages({}) == serialize_json_data([])
+
+    # Valid prompts
+    process_data = {
+        "prompts": [
+            {"role": "user", "text": "hello"},
+            {"role": "assistant", "text": "hi"},
+            {"role": "system", "text": "be helpful"},
+            {"role": "tool", "text": "result"},
+            {"role": "invalid", "text": "skip me"},
+            "not a dict",
+            {"role": "user", "text": ""},  # Empty text is skipped by the `if text:` guard
+        ]
+    }
+    result = format_input_messages(process_data)
+    result_list = json.loads(result)
+
+    assert len(result_list) == 4
+    assert result_list[0]["role"] == "user"
+    assert result_list[0]["parts"][0]["content"] == "hello"
+    assert result_list[1]["role"] == "assistant"
+    assert result_list[2]["role"] == "system"
+    assert result_list[3]["role"] == "tool"
+
+    # Exception path
+    assert format_input_messages({"prompts": [None]}) == serialize_json_data([])
+
+
+def test_format_output_messages():
+    # Not a dict
+    assert format_output_messages(None) == serialize_json_data([])
+
+    # No text
+    assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])
+
+    # Valid
+    outputs = {"text": "done", "finish_reason": "length"}
+    result = format_output_messages(outputs)
+    result_list = json.loads(result)
+    assert len(result_list) == 1
+    assert result_list[0]["role"] == "assistant"
+    assert result_list[0]["parts"][0]["content"] == "done"
+    assert result_list[0]["finish_reason"] == "length"
+
+    # An unrecognized finish reason falls back to "stop"
+    outputs2 = {"text": "done", "finish_reason": "unknown"}
+    result2 = format_output_messages(outputs2)
+    result_list2 = json.loads(result2)
+    assert result_list2[0]["finish_reason"] == "stop"
+
+    # Exception path: a non-serializable value makes serialize_json_data fail
+    assert format_output_messages({"text": MagicMock()}) == serialize_json_data([])
diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
new file mode 100644
index 0000000000..1cee2f5b68
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
@@ -0,0 +1,398 @@
+from datetime import UTC, datetime, timedelta
+from unittest.mock import MagicMock, patch
+
+import pytest
+from opentelemetry.sdk.trace import Tracer
+from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
+from opentelemetry.trace import StatusCode
+
+from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
+    ArizePhoenixDataTrace,
+    datetime_to_nanos,
+    error_to_string,
+    safe_json_dumps,
+    set_span_status,
+    setup_tracer,
+    wrap_span_metadata,
+)
+from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
+
+# --- Helpers ---
+
+
+def
_make_workflow_info(**kwargs): + defaults = { + "workflow_id": "w1", + "tenant_id": "t1", + "workflow_run_id": "r1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"in": "val"}, + "workflow_run_outputs": {"out": "val"}, + "workflow_run_version": "1.0", + "total_tokens": 10, + "file_list": ["f1"], + "query": "hi", + "metadata": {"app_id": "app1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowTraceInfo(**defaults) + + +def _make_message_info(**kwargs): + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 5, + "total_tokens": 10, + "conversation_mode": "chat", + "metadata": {"app_id": "app1"}, + "inputs": {"in": "val"}, + "outputs": "val", + "start_time": _dt(), + "end_time": _dt(), + "message_id": "m1", + } + defaults.update(kwargs) + return MessageTraceInfo(**defaults) + + +# --- Utility Function Tests --- + + +def test_datetime_to_nanos(): + dt = _dt() + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanos(dt) == expected + + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + mock_now = MagicMock() + mock_now.timestamp.return_value = 1704110400.0 + mock_dt.now.return_value = mock_now + assert datetime_to_nanos(None) == 1704110400000000000 + + +def test_error_to_string(): + try: + raise ValueError("boom") + except ValueError as e: + err = e + + res = error_to_string(err) + assert "ValueError: boom" in res + assert "traceback" in res.lower() or "line" in res.lower() + + assert error_to_string("str error") == "str error" + assert error_to_string(None) == "Empty Stack Trace" + + +def test_set_span_status(): + span = MagicMock() + # OK + set_span_status(span, None) + span.set_status.assert_called() + assert span.set_status.call_args[0][0].status_code == StatusCode.OK + + # Error Exception + span.reset_mock() + set_span_status(span, ValueError("fail")) + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.record_exception.assert_called() + + # Error String + span.reset_mock() + set_span_status(span, "fail-str") + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.add_event.assert_called() + + # repr branch + class SilentError: + def __str__(self): + return "" + + def __repr__(self): + return "SilentErrorRepr" + + span.reset_mock() + set_span_status(span, SilentError()) + assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" + + +def test_safe_json_dumps(): + assert safe_json_dumps({"a": _dt()}) == '{"a": "2024-01-01 00:00:00+00:00"}' + + +def test_wrap_span_metadata(): + res = wrap_span_metadata({"a": 1}, b=2) + assert res == {"a": 1, "b": 2, "created_from": "Dify"} + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_arize(mock_provider, mock_exporter): + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_phoenix(mock_provider, mock_exporter): + config = PhoenixConfig(endpoint="http://p.com", 
project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces" + + +def test_setup_tracer_exception(): + config = ArizeConfig(endpoint="http://a.com", project="p") + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with pytest.raises(Exception, match="boom"): + setup_tracer(config) + + +# --- ArizePhoenixDataTrace Class Tests --- + + +@pytest.fixture +def trace_instance(): + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + mock_tracer = MagicMock(spec=Tracer) + mock_processor = MagicMock() + mock_setup.return_value = (mock_tracer, mock_processor) + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + return ArizePhoenixDataTrace(config) + + +def test_trace_dispatch(trace_instance): + with ( + patch.object(trace_instance, "workflow_trace") as m1, + patch.object(trace_instance, "message_trace") as m2, + patch.object(trace_instance, "moderation_trace") as m3, + patch.object(trace_instance, "suggested_question_trace") as m4, + patch.object(trace_instance, "dataset_retrieval_trace") as m5, + patch.object(trace_instance, "tool_trace") as m6, + patch.object(trace_instance, "generate_name_trace") as m7, + ): + trace_instance.trace(_make_workflow_info()) + m1.assert_called() + + trace_instance.trace(_make_message_info()) + m2.assert_called() + + trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})) + m3.assert_called() + + trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={})) + m4.assert_called() + + trace_instance.trace(DatasetRetrievalTraceInfo(metadata={})) + m5.assert_called() + + trace_instance.trace( + ToolTraceInfo( + tool_name="t", + tool_inputs={}, + tool_outputs="o", + metadata={}, + tool_config={}, + time_cost=1, + tool_parameters={}, + ) + ) + m6.assert_called() + + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + m7.assert_called() + + +def test_trace_exception(trace_instance): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail")): + with pytest.raises(RuntimeError): + trace_instance.trace(_make_workflow_info()) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + node1 = MagicMock() + node1.node_type = "llm" + node1.status = "succeeded" + node1.inputs = {"q": "hi"} + node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}} + node1.created_at = _dt() + node1.elapsed_time = 1.0 + node1.process_data = { + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + } + node1.metadata = {"k": "v"} + node1.title = "title" + node1.id = "n1" + node1.error = None + + repo.get_by_workflow_run.return_value = [node1] + + with patch.object(trace_instance, "get_service_account_with_tenant"): + trace_instance.workflow_trace(info) + + assert trace_instance.tracer.start_span.call_count >= 2 + + 
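+
+# workflow_trace() requires metadata["app_id"] (used by get_service_account_with_tenant);
+# the test below covers the guard that raises when it is missing.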
+@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_no_app_id(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + info.metadata = {} + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(info) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_success(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = None + info.error = None + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_with_error(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = "processing failed" + info.error = "message error" + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_trace_methods_return_early_with_no_message_data(trace_instance): + info = MagicMock() + info.message_data = None + + trace_instance.moderation_trace(info) + trace_instance.suggested_question_trace(info) + trace_instance.dataset_retrieval_trace(info) + trace_instance.tool_trace(info) + trace_instance.generate_name_trace(info) + + assert trace_instance.tracer.start_span.call_count == 0 + + +def test_moderation_trace_ok(trace_instance): + info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}) + info.message_data = MagicMock() + info.message_data.error = None + trace_instance.moderation_trace(info) + # root span (1) + moderation span (1) = 2 + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_suggested_question_trace_ok(trace_instance): + info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.suggested_question_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_dataset_retrieval_trace_ok(trace_instance): + info = DatasetRetrievalTraceInfo(documents=[], metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.dataset_retrieval_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_tool_trace_ok(trace_instance): + info = ToolTraceInfo( + tool_name="t", tool_inputs={}, tool_outputs="o", metadata={}, tool_config={}, time_cost=1, tool_parameters={} + ) + info.message_data = MagicMock() + info.error = None + trace_instance.tool_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_generate_name_trace_ok(trace_instance): + info = GenerateNameTraceInfo(tenant_id="t", metadata={}) + info.message_data = MagicMock() + 
info.message_data.error = None + trace_instance.generate_name_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_get_project_url_phoenix(trace_instance): + trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p") + assert "p.com/projects/" in trace_instance.get_project_url() + + +def test_set_attribute_none_logic(trace_instance): + # Test role can be None + attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}]) + assert "llm.input_messages.0.message.role" not in attrs + + # Test tool call id can be None + tool_call_none_id = {"id": None, "function": {"name": "f1"}} + attrs = trace_instance._construct_llm_attributes([{"role": "assistant", "tool_calls": [tool_call_none_id]}]) + assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in attrs + + +def test_construct_llm_attributes_dict_branch(trace_instance): + attrs = trace_instance._construct_llm_attributes({"prompt": "hi"}) + assert '"prompt": "hi"' in attrs["llm.input_messages.0.message.content"] + assert attrs["llm.input_messages.0.message.role"] == "user" + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + + +def test_ensure_root_span_basic(trace_instance): + trace_instance.ensure_root_span("tid") + assert "tid" in trace_instance.dify_trace_ids diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py new file mode 100644 index 0000000000..8e036e4b52 --- /dev/null +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -0,0 +1,698 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangfuseConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from dify_graph.enums import NodeType +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def langfuse_config(): + return LangfuseConfig(public_key="pk-123", secret_key="sk-123", host="https://cloud.langfuse.com") + + +@pytest.fixture +def trace_instance(langfuse_config, monkeypatch): + # Mock Langfuse client to avoid network calls + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + + instance = LangFuseDataTrace(langfuse_config) + return instance + + +def test_init(langfuse_config, monkeypatch): + mock_langfuse = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangFuseDataTrace(langfuse_config) + + mock_langfuse.assert_called_once_with( + public_key=langfuse_config.public_key, + secret_key=langfuse_config.secret_key, + host=langfuse_config.host, + ) + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): 
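+    # Patch every per-type handler, then verify trace() routes each TraceInfo
+    # subclass to its matching handler exactly once.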
+ methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Setup trace info + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id="msg-1", + conversation_id="conv-1", + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id="log-1", + error="", + ) + + # Mock DB and Repositories + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Other Node" + node_other.node_type = NodeType.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None # Trigger datetime.now() branch + node_other.elapsed_time = 0.2 + node_other.metadata = None + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + 
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + # Track calls to add_trace, add_span, add_generation + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_trace (Workflow Level) + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "trace-1" + assert trace_data.name == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data.tags + assert "workflow" in trace_data.tags + + # Verify add_span (Workflow Run Span) + assert trace_instance.add_span.call_count >= 1 + # First span should be workflow run span because message_id is present + workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"] + assert workflow_span.id == "run-1" + assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE + + # Verify Generation for LLM node + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.id == "node-llm" + assert gen_data.usage.input == 10 + assert gen_data.usage.output == 20 + + # Verify normal span for Other node + # Second add_span call + other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"] + assert other_span.id == "node-other" + assert other_span.level == LevelEnum.ERROR + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id=None, # Should fallback to workflow_run_id + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + ) + + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "run-1" + assert trace_data.name == TraceTaskName.WORKFLOW_TRACE + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + 
start_time=_dt(), + end_time=_dt(), + metadata={}, # Missing app_id + workflow_app_log_id="log-1", + error="", + ) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = "conv-1" + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_generation.assert_called_once() + + gen_data = trace_instance.add_generation.call_args[0][0] + assert gen_data.name == "llm" + assert gen_data.usage.total == 30 + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "conv-1" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + # Mock DB session for EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.user_id == "session-id-123" + assert trace_data.metadata["user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + + trace_info = ModerationTraceInfo( + 
message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.MODERATION_TRACE + assert span_data.output["flagged"] is True + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.name == TraceTaskName.SUGGESTED_QUESTION_TRACE + assert gen_data.usage.unit == UnitEnum.CHARACTERS + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.DATASET_RETRIEVAL_TRACE + assert span_data.output["documents"] == [{"id": "doc1"}] + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == "my_tool" + assert span_data.level == LevelEnum.ERROR + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + metadata={"m": 1}, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.name == TraceTaskName.GENERATE_NAME_TRACE + assert trace_data.user_id == "tenant-1" + + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.trace_id == "conv-1" + + +def test_add_trace_success(trace_instance): + data = 
LangfuseTrace(id="t1", name="trace") + trace_instance.add_trace(data) + trace_instance.langfuse_client.trace.assert_called_once() + + +def test_add_trace_error(trace_instance): + trace_instance.langfuse_client.trace.side_effect = Exception("error") + data = LangfuseTrace(id="t1", name="trace") + with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"): + trace_instance.add_trace(data) + + +def test_add_span_success(trace_instance): + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.add_span(data) + trace_instance.langfuse_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + trace_instance.langfuse_client.span.side_effect = Exception("error") + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create span: error"): + trace_instance.add_span(data) + + +def test_update_span(trace_instance): + span = MagicMock() + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.update_span(span, data) + span.end.assert_called_once() + + +def test_add_generation_success(trace_instance): + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.add_generation(data) + trace_instance.langfuse_client.generation.assert_called_once() + + +def test_add_generation_error(trace_instance): + trace_instance.langfuse_client.generation.side_effect = Exception("error") + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"): + trace_instance.add_generation(data) + + +def test_update_generation(trace_instance): + gen = MagicMock() + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.update_generation(gen, data) + gen.end.assert_called_once() + + +def test_api_check_success(trace_instance): + trace_instance.langfuse_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.langfuse_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_key_success(trace_instance): + mock_data = MagicMock() + mock_data.id = "proj-1" + trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data]) + assert trace_instance.get_project_key() == "proj-1" + + +def test_get_project_key_error(trace_instance): + trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse get project key failed: fail"): + trace_instance.get_project_key() + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="m", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={} + ) + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_generation.assert_not_called() + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = 
DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={}) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_langfuse_trace_entity_with_list_dict_input(): + # Covers the list input normalization branch (lines 29-31) of langfuse_trace_entity.py: + # replace_text_with_content (imported from core.ops.utils) rewrites "text" keys + # in list inputs to "content", so no mocking is required here. + data = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}]) + assert isinstance(data.input, list) + assert data.input[0]["content"] == "hello" + + +def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog): + # Setup trace info to trigger LLM node usage extraction + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="t", + workflow_run_id="r", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="c", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "app-1"}, + workflow_app_log_id="l", + error="", + ) + + node = MagicMock() + node.id = "n1" + node.title = "LLM Node" + node.node_type = NodeType.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + trace_instance.add_generation.assert_called_once() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py new file mode 100644 index 0000000000..98f9dd00cf --- /dev/null +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -0,0 +1,608 @@ +import collections +import logging +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangSmithConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from core.ops.langsmith_trace.langsmith_trace import
LangSmithDataTrace +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from models import EndUser + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0) + + +@pytest.fixture +def langsmith_config(): + return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com") + + +@pytest.fixture +def trace_instance(langsmith_config, monkeypatch): + # Mock LangSmith client + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + + instance = LangSmithDataTrace(langsmith_config) + return instance + + +def test_init(langsmith_config, monkeypatch): + mock_client_class = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangSmithDataTrace(langsmith_config) + + mock_client_class.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint) + assert instance.langsmith_key == langsmith_config.api_key + assert instance.project_name == langsmith_config.project + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace(trace_instance, monkeypatch): + # Setup trace info + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + # Mock dependencies + mock_session = MagicMock() + 
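# Stub the session factory and engine so the test never touches a real database. +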
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Tool Node" + node_other.node_type = NodeType.TOOL + node_other.status = "succeeded" + node_other.process_data = None + node_other.inputs = {"tool_input": "val"} + node_other.outputs = {"tool_output": "val"} + node_other.created_at = None # Trigger datetime.now() + node_other.elapsed_time = 0.2 + node_other.metadata = {} + + node_retrieval = MagicMock() + node_retrieval.id = "node-retrieval" + node_retrieval.title = "Retrieval Node" + node_retrieval.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_retrieval.status = "succeeded" + node_retrieval.process_data = None + node_retrieval.inputs = {"query": "val"} + node_retrieval.outputs = {"results": "val"} + node_retrieval.created_at = _dt() + node_retrieval.elapsed_time = 0.2 + node_retrieval.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_run calls + # 1. message run (id="msg-1") + # 2. workflow run (id="run-1") + # 3. node llm run (id="node-llm") + # 4. node other run (id="node-other") + # 5. 
node retrieval run (id="node-retrieval") + assert trace_instance.add_run.call_count == 5 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + + assert call_args[0].id == "msg-1" + assert call_args[0].name == TraceTaskName.MESSAGE_TRACE + + assert call_args[1].id == "run-1" + assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE + assert call_args[1].parent_run_id == "msg-1" + + assert call_args[2].id == "node-llm" + assert call_args[2].run_type == LangSmithRunType.llm + + assert call_args[3].id == "node-other" + assert call_args[3].run_type == LangSmithRunType.tool + + assert call_args[4].id == "node-retrieval" + assert call_args[4].run_type == LangSmithRunType.retriever + + +def test_workflow_trace_no_start_time(trace_instance, monkeypatch): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=10, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + trace_instance.workflow_trace(trace_info) + assert trace_instance.add_run.called + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.trace_id = "trace-1" + trace_info.message_id = None + trace_info.workflow_run_id = "run-1" + trace_info.start_time = None + trace_info.workflow_data = MagicMock() + trace_info.workflow_data.created_at = _dt() + trace_info.metadata = {} # Empty metadata + trace_info.workflow_app_log_id = "log-1" + trace_info.file_list = [] + trace_info.total_tokens = 0 + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.error = "" + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.answer = "hello answer" + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"input": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + 
answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="file-url"), + ) + + # Mock EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_run = MagicMock() + + trace_instance.message_trace(trace_info) + + # 1. message run + # 2. llm run + assert trace_instance.add_run.call_count == 2 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + assert call_args[0].id == "msg-1" + assert call_args[0].extra["metadata"]["end_user_id"] == "session-id-123" + assert call_args[1].parent_run_id == "msg-1" + assert call_args[1].name == "llm" + + +def test_message_trace_no_data(trace_instance): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_data = None + trace_info.file_list = [] + trace_info.message_file_data = None + trace_info.metadata = {} + trace_instance.add_run = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace_no_data(trace_instance): + trace_info = MagicMock(spec=ModerationTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_suggested_question_trace_no_data(trace_instance): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_dataset_retrieval_trace_no_data(trace_instance): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def 
test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + file_url="http://file", + ) + + trace_instance.add_run = MagicMock() + trace_instance.tool_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.generate_name_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE + + +def test_add_run_success(trace_instance): + run_data = LangSmithRunModel( + id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt() + ) + trace_instance.project_id = "proj-1" + trace_instance.add_run(run_data) + trace_instance.langsmith_client.create_run.assert_called_once() + args, kwargs = trace_instance.langsmith_client.create_run.call_args + assert kwargs["session_id"] == "proj-1" + + +def test_add_run_error(trace_instance): + run_data = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt()) + trace_instance.langsmith_client.create_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"): + trace_instance.add_run(run_data) + + +def test_update_run_success(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"}) + trace_instance.update_run(update_data) + trace_instance.langsmith_client.update_run.assert_called_once() + + +def test_update_run_error(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1") + trace_instance.langsmith_client.update_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"): + trace_instance.update_run(update_data) + + +def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + 
conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node_llm.inputs = {} + node_llm.outputs = {} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + import logging + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + assert trace_instance.langsmith_client.create_project.called + assert trace_instance.langsmith_client.delete_project.called + + +def test_api_check_error(trace_instance): + trace_instance.langsmith_client.create_project.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith API check failed: error"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.langsmith_client.get_run_url.return_value = "https://smith.langchain.com/o/org/p/proj/r/run" + url = trace_instance.get_project_url() + assert url == "https://smith.langchain.com/o/org/p/proj" + + +def test_get_project_url_error(trace_instance): + trace_instance.langsmith_client.get_run_url.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith get run url failed: error"): + trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py new file mode 100644 index 0000000000..0657acc1d9 --- /dev/null +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -0,0 +1,1019 @@ +"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" + +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from dify_graph.enums import NodeType + +# ── Helpers 
────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "conversation_id": "c1"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1", "from_account_id": "a1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace( + model_provider="openai", + model_id="gpt-4", + total_price=0.01, + answer="response text", + ), + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(), + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution row.""" + defaults = { + "id": "node-1", + "tenant_id": "t1", + "app_id": "app-1", + "title": "Node Title", + "node_type": 
NodeType.CODE, + "status": "succeeded", + "inputs": '{"key": "value"}', + "outputs": '{"result": "ok"}', + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "execution_metadata": None, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_mlflow(): + with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + yield mock + + +@pytest.fixture +def mock_tracing(): + """Patch all MLflow tracing functions used by the module.""" + with ( + patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, + patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, + patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, + patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + ): + yield { + "start": mock_start, + "update": mock_update, + "set": mock_set, + "detach": mock_detach, + } + + +@pytest.fixture +def mock_db(): + with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + yield mock + + +@pytest.fixture +def trace_instance(mock_mlflow): + """Create an MLflowDataTrace using a basic MLflowConfig (no auth).""" + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + return MLflowDataTrace(config) + + +# ── datetime_to_nanoseconds ───────────────────────────────────────────────── + + +class TestDatetimeToNanoseconds: + def test_none_returns_none(self): + assert datetime_to_nanoseconds(None) is None + + def test_converts_datetime(self): + dt = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanoseconds(dt) == expected + + +# ── __init__ / setup ───────────────────────────────────────────────────────── + + +class TestInit: + def test_mlflow_config_no_auth(self, mock_mlflow): + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + trace = MLflowDataTrace(config) + mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000") + mock_mlflow.set_experiment.assert_called_with(experiment_id="0") + assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces" + assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true" + + def test_mlflow_config_with_auth(self, mock_mlflow): + config = MLflowConfig( + tracking_uri="http://localhost:5000", + experiment_id="1", + username="user", + password="pass", + ) + MLflowDataTrace(config) + assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user" + assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass" + + def test_databricks_oauth(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com/", + experiment_id="42", + client_id="cid", + client_secret="csec", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_HOST"] == "https://db.com/" + assert os.environ["DATABRICKS_CLIENT_ID"] == "cid" + assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec" + mock_mlflow.set_tracking_uri.assert_called_with("databricks") + # Trailing slash stripped + assert trace.get_project_url() == "https://db.com/ml/experiments/42/traces" + + def test_databricks_pat(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com", + experiment_id="1", + personal_access_token="pat", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_TOKEN"] == "pat" + assert "db.com/ml/experiments/1/traces" in trace.get_project_url() + + def 
test_databricks_no_creds_raises(self, mock_mlflow): + config = DatabricksConfig(host="https://db.com", experiment_id="1") + with pytest.raises(ValueError, match="Either Databricks token"): + MLflowDataTrace(config) + + +# ── trace dispatcher ──────────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_tool(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "tool_trace") as mock_tt: + trace_instance.trace(_make_tool_trace_info()) + mock_tt.assert_called_once() + + def test_dispatches_moderation(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + trace_instance.trace(_make_moderation_trace_info(message_data=SimpleNamespace(created_at=_dt()))) + mock_mod.assert_called_once() + + def test_dispatches_dataset_retrieval(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_suggested_question(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_generate_name(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + def test_reraises_exception(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("boom")): + with pytest.raises(RuntimeError, match="boom"): + trace_instance.trace(_make_workflow_trace_info()) + + +# ── workflow_trace ─────────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(conversation_id="sess-1") + trace_instance.workflow_trace(trace_info) + + # Workflow span started and ended + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + workflow_run_inputs={"sys.app_id": "x", "user_input": "hi"}, + query="hello", + ) + trace_instance.workflow_trace(trace_info) + + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "sys.app_id" not in inputs + 
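# sys.* keys are filtered out of the span inputs; plain user inputs and the query survive. +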
assert inputs["user_input"] == "hi" + assert inputs["query"] == "hello" + + def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db): + llm_node = _make_node( + node_type=NodeType.LLM, + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ), + outputs='{"text": "hello world"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + node_span.end.assert_called_once() + workflow_span.end.assert_called_once() + + def test_workflow_with_question_classifier_node(self, trace_instance, mock_tracing, mock_db): + qc_node = _make_node( + node_type=NodeType.QUESTION_CLASSIFIER, + process_data=json.dumps( + { + "prompts": "classify this", + "model_name": "gpt-4", + "model_provider": "openai", + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + + def test_workflow_with_http_request_node(self, trace_instance, mock_tracing, mock_db): + http_node = _make_node( + node_type=NodeType.HTTP_REQUEST, + process_data='{"url": "https://api.com"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # HTTP_REQUEST uses process_data as inputs + node_start_call = mock_tracing["start"].call_args_list[1] + assert node_start_call.kwargs["inputs"] == '{"url": "https://api.com"}' + + def test_workflow_with_knowledge_retrieval_node(self, trace_instance, mock_tracing, mock_db): + kr_node = _make_node( + node_type=NodeType.KNOWLEDGE_RETRIEVAL, + outputs=json.dumps( + { + "result": [ + {"content": "doc1", "metadata": {"source": "s1"}}, + {"content": "doc2", "metadata": {}}, + ] + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # outputs should be parsed to Document objects + end_call = node_span.end.call_args + outputs = end_call.kwargs["outputs"] + assert len(outputs) == 2 + + def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): + failed_node = _make_node(status="failed") + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = 
"token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_span.set_status.assert_called_once() + node_span.add_event.assert_called_once() + + def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + workflow_span = MagicMock() + mock_tracing["start"].return_value = workflow_span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(error="workflow failed") + trace_instance.workflow_trace(trace_info) + workflow_span.set_status.assert_called_once() + workflow_span.add_event.assert_called_once() + # Still ends the span via finally + workflow_span.end.assert_called_once() + + def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): + node = _make_node(inputs=None, outputs=None) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_call = mock_tracing["start"].call_args_list[1] + assert node_call.kwargs["inputs"] == {} + end_call = node_span.end.call_args + assert end_call.kwargs["outputs"] == {} + + def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + metadata={}, + conversation_id=None, + ) + trace_instance.workflow_trace(trace_info) + # _set_trace_metadata still called with empty metadata + mock_tracing["update"].assert_called_once() + + def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): + """When query is empty string, it's falsy so no query key added.""" + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(query="") + trace_instance.workflow_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "query" not in inputs + + +# ── _parse_llm_inputs_and_attributes ───────────────────────────────────────── + + +class TestParseLlmInputsAndAttributes: + def test_none_process_data(self, trace_instance): + node = _make_node(process_data=None) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_invalid_json(self, trace_instance): + node = _make_node(process_data="not json") + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_valid_process_data_with_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert isinstance(inputs, list) + assert attrs["model_name"] == 
"gpt-4" + assert "usage" in attrs + + def test_valid_process_data_without_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": "simple prompt", + "model_name": "gpt-3.5", + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == "simple prompt" + assert attrs["model_name"] == "gpt-3.5" + + +# ── _parse_knowledge_retrieval_outputs ─────────────────────────────────────── + + +class TestParseKnowledgeRetrievalOutputs: + def test_with_results(self, trace_instance): + outputs = {"result": [{"content": "c1", "metadata": {"s": "1"}}]} + docs = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert len(docs) == 1 + assert docs[0].page_content == "c1" + + def test_empty_result(self, trace_instance): + outputs = {"result": []} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_no_result_key(self, trace_instance): + outputs = {"other": "data"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_result_not_list(self, trace_instance): + outputs = {"result": "not a list"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + +# ── message_trace ──────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing, mock_db): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_message_trace(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_instance.message_trace(_make_message_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_message_trace_with_error(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info(error="something broke") + trace_instance.message_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setenv("FILES_URL", "http://files.test") + + file_data = SimpleNamespace(url="path/to/file.png") + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing_file.txt"], + ) + trace_instance.message_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + attrs = call_kwargs.kwargs["attributes"] + assert "http://files.test/path/to/file.png" in attrs["file_list"] + assert "existing_file.txt" in attrs["file_list"] + + def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = 
None + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_called_once() + + def test_message_trace_with_end_user(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + end_user = MagicMock() + end_user.session_id = "session-xyz" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + + trace_info = _make_message_trace_info( + metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, + ) + trace_instance.message_trace(trace_info) + # update_current_trace called with user id from EndUser + mock_tracing["update"].assert_called_once() + + def test_message_trace_with_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info( + metadata={"from_account_id": "acc-1"}, + ) + trace_instance.message_trace(trace_info) + mock_tracing["update"].assert_called_once() + + +# ── _get_message_user_id ───────────────────────────────────────────────────── + + +class TestGetMessageUserId: + def test_returns_end_user_session_id(self, trace_instance, mock_db): + end_user = MagicMock() + end_user.session_id = "session-1" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) + assert result == "session-1" + + def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_account_id_when_no_end_user_id(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({"from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_none_when_nothing(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({}) + assert result is None + + +# ── tool_trace ─────────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + span.set_status.assert_not_called() + + def test_tool_trace_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info(error="tool failed")) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + span.end.assert_called_once() + + +# ── moderation_trace ───────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_moderation_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + 
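# The moderation verdict (action / flagged) should be surfaced in the span outputs on end. +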
trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=_dt(), + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + end_kwargs = span.end.call_args.kwargs["outputs"] + assert end_kwargs["action"] == "allow" + assert end_kwargs["flagged"] is False + + def test_moderation_uses_message_data_created_at_if_no_start_time(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=None, + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + + +# ── dataset_retrieval_trace ────────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── suggested_question_trace ───────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.suggested_question_trace(_make_suggested_question_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_suggested_question_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info(error="failed") + trace_instance.suggested_question_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_uses_message_data_times_when_no_start_end(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info( + start_time=None, + end_time=None, + ) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── generate_name_trace ────────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.generate_name_trace(_make_generate_name_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── _get_workflow_nodes ────────────────────────────────────────────────────── + + +class TestGetWorkflowNodes: + def test_queries_db(self, trace_instance, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", 
"n2"] + result = trace_instance._get_workflow_nodes("run-1") + assert result == ["n1", "n2"] + + +# ── _get_node_span_type ───────────────────────────────────────────────────── + + +class TestGetNodeSpanType: + @pytest.mark.parametrize( + ("node_type", "expected_contains"), + [ + (NodeType.LLM, "LLM"), + (NodeType.QUESTION_CLASSIFIER, "LLM"), + (NodeType.KNOWLEDGE_RETRIEVAL, "RETRIEVER"), + (NodeType.TOOL, "TOOL"), + (NodeType.CODE, "TOOL"), + (NodeType.HTTP_REQUEST, "TOOL"), + (NodeType.AGENT, "AGENT"), + ], + ) + def test_mapped_types(self, trace_instance, node_type, expected_contains): + result = trace_instance._get_node_span_type(node_type) + assert expected_contains in str(result) + + def test_unknown_type_returns_chain(self, trace_instance): + result = trace_instance._get_node_span_type("unknown_node") + assert result == "CHAIN" + + +# ── _set_trace_metadata ───────────────────────────────────────────────────── + + +class TestSetTraceMetadata: + def test_sets_and_detaches(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + + trace_instance._set_trace_metadata(span, {"key": "val"}) + mock_tracing["set"].assert_called_once_with(span) + mock_tracing["update"].assert_called_once_with(metadata={"key": "val"}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_detaches_even_on_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + mock_tracing["update"].side_effect = RuntimeError("fail") + + with pytest.raises(RuntimeError): + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_no_detach_when_token_is_none(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = None + + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_not_called() + + +# ── _parse_prompts ─────────────────────────────────────────────────────────── + + +class TestParsePrompts: + def test_string_input(self, trace_instance): + assert trace_instance._parse_prompts("hello") == "hello" + + def test_dict_input(self, trace_instance): + result = trace_instance._parse_prompts({"role": "user", "text": "hi"}) + assert result == {"role": "user", "content": "hi"} + + def test_list_input(self, trace_instance): + prompts = [ + {"role": "user", "text": "hi"}, + {"role": "assistant", "text": "hello"}, + ] + result = trace_instance._parse_prompts(prompts) + assert len(result) == 2 + assert result[0]["role"] == "user" + + def test_none_input(self, trace_instance): + assert trace_instance._parse_prompts(None) is None + + def test_int_passthrough(self, trace_instance): + assert trace_instance._parse_prompts(42) == 42 + + +# ── _parse_single_message ─────────────────────────────────────────────────── + + +class TestParseSingleMessage: + def test_basic_message(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hello"}) + assert result == {"role": "user", "content": "hello"} + + def test_default_role(self, trace_instance): + result = trace_instance._parse_single_message({"text": "hello"}) + assert result["role"] == "user" + + def test_with_tool_calls(self, trace_instance): + item = { + "role": "assistant", + "text": "", + "tool_calls": [{"id": "tc1", "function": {"name": "fn"}}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" in result + + def test_tool_role_ignores_tool_calls(self, trace_instance): + item = { + 
"role": "tool", + "text": "result", + "tool_calls": [{"id": "tc1"}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" not in result + + def test_with_files(self, trace_instance): + item = {"role": "user", "text": "look", "files": ["f1.png"]} + result = trace_instance._parse_single_message(item) + assert result["files"] == ["f1.png"] + + def test_no_files(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hi"}) + assert "files" not in result + + +# ── _resolve_tool_call_ids ─────────────────────────────────────────────────── + + +class TestResolveToolCallIds: + def test_resolves_tool_call_ids(self, trace_instance): + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "tc1"}, {"id": "tc2"}], + }, + {"role": "tool", "content": "result1"}, + {"role": "tool", "content": "result2"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert result[1]["tool_call_id"] == "tc1" + assert result[2]["tool_call_id"] == "tc2" + + def test_no_tool_calls(self, trace_instance): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + assert "tool_call_id" not in result[1] + + def test_tool_message_no_ids_available(self, trace_instance): + """Tool message with no preceding tool_calls should not crash.""" + messages = [ + {"role": "tool", "content": "result"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + + +# ── api_check ──────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_success(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.return_value = [] + assert trace_instance.api_check() is True + + def test_failure(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.side_effect = ConnectionError("refused") + with pytest.raises(ValueError, match="MLflow connection failed"): + trace_instance.api_check() + + +# ── get_project_url ────────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_returns_url(self, trace_instance): + assert "experiments" in trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py new file mode 100644 index 0000000000..80a0331c4b --- /dev/null +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -0,0 +1,678 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import OpikConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def opik_config(): + return OpikConfig( + project="test-project", workspace="test-workspace", 
url="https://cloud.opik.com/api/", api_key="api-key-123" + ) + + +@pytest.fixture +def trace_instance(opik_config, monkeypatch): + mock_client = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + + instance = OpikDataTrace(opik_config) + return instance + + +def test_wrap_dict(): + assert wrap_dict("input", {"a": 1}) == {"a": 1} + assert wrap_dict("input", "hello") == {"input": "hello"} + + +def test_wrap_metadata(): + assert wrap_metadata({"a": 1}, b=2) == {"a": 1, "b": 2, "created_from": "dify"} + + +def test_prepare_opik_uuid(): + # Test with valid datetime and uuid string + dt = datetime(2024, 1, 1) + uuid_str = "b3e8e918-472e-4b69-8051-12502c34fc07" + result = prepare_opik_uuid(dt, uuid_str) + assert result is not None + # We won't test the exact uuid7 value but just that it returns a string id + + # Test with None dt and uuid_str + result = prepare_opik_uuid(None, None) + assert result is not None + + +def test_init(opik_config, monkeypatch): + mock_opik = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = OpikDataTrace(opik_config) + + mock_opik.assert_called_once_with( + project_name=opik_config.project, + workspace=opik_config.workspace, + host=opik_config.url, + api_key=opik_config.api_key, + ) + assert instance.file_base_url == "http://test.url" + assert instance.project == opik_config.project + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce" + WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef" + MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6" + CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e" + TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3" + WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b" + LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3" + CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900" + + 
trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id=MESSAGE_ID, + conversation_id=CONVERSATION_ID, + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=TRACE_ID, + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + node_llm = MagicMock() + node_llm.id = LLM_NODE_ID + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = CODE_NODE_ID + node_other.title = "Other Node" + node_other.node_type = NodeType.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None + node_other.elapsed_time = 0.2 + node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1].get("opik_trace_data", trace_instance.add_trace.call_args[0][0]) + assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data["tags"] + assert "workflow" in trace_data["tags"] + + assert trace_instance.add_span.call_count >= 1 + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3" + WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac" + CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c" + WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb" + + trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id=CONVERSATION_ID, + start_time=_dt(), + end_time=_dt(), + trace_id=None, + metadata={"app_id": "app-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + 
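# Stub the DB session, node-execution repository, and service-account lookup + # so workflow_trace can run without real infrastructure; the empty node list + # below keeps this test focused on the top-level trace call. + 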
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a", + tenant_id="tenant-1", + workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea", + start_time=_dt(), + end_time=_dt(), + metadata={}, + workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", + error="", + ) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + # Define constants for better readability + MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab" + CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3" + MESSAGE_TRACE_ID = "710ace2f-bca8-41be-858c-54da42742a77" + OPIT_TRACE_ID = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c" + + message_data = MagicMock() + message_data.id = MESSAGE_DATA_ID + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = CONVERSATION_ID + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id=MESSAGE_TRACE_ID, + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=OPIT_TRACE_ID, + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="test.png"), + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + 
message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=["url1"], + error=None, + ) + + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["metadata"]["user_id"] == "acc-1" + assert trace_data["metadata"]["end_user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], message_file_data=None, metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="489d0dfd-065c-4106-8f9c-daded296c92d", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.MODERATION_TRACE + assert span_data["output"]["flagged"] is True + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="7de55bda-a91d-477e-98ab-85c53c438469", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a6687292-68c7-42ba-ae51-285579944d7b", + ) + + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d", + message_data=None, + inputs={}, + suggested_question=[], + total_tokens=0, + 
level="i", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="3e1a819f-c391-4950-adfd-96f82e5419a1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = DatasetRetrievalTraceInfo( + message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8", message_data=None, inputs={}, documents=[], metadata={} + ) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="99db92c4-2254-496a-b5cc-18153315ce35", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791", + start_time=_dt(), + end_time=_dt(), + metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1}, + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3")) + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["name"] == TraceTaskName.GENERATE_NAME_TRACE + + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["trace_id"] == "trace_id_3" + + +def test_add_trace_success(trace_instance): + trace_data = {"id": "t1", "name": "trace"} + trace_instance.opik_client.trace.return_value = MagicMock(id="t1") + trace = trace_instance.add_trace(trace_data) + trace_instance.opik_client.trace.assert_called_once() + assert trace.id == "t1" + + +def test_add_trace_error(trace_instance): + trace_instance.opik_client.trace.side_effect = Exception("error") + trace_data = {"id": "t1", "name": "trace"} + with pytest.raises(ValueError, match="Opik Failed to create trace: error"): + trace_instance.add_trace(trace_data) + + +def test_add_span_success(trace_instance): + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + trace_instance.add_span(span_data) + trace_instance.opik_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + 
trace_instance.opik_client.span.side_effect = Exception("error") + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + with pytest.raises(ValueError, match="Opik Failed to create span: error"): + trace_instance.add_span(span_data) + + +def test_api_check_success(trace_instance): + trace_instance.opik_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.opik_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.opik_client.get_project_url.return_value = "http://project.url" + assert trace_instance.get_project_url() == "http://project.url" + trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project) + + +def test_get_project_url_error(trace_instance): + trace_instance.opik_client.get_project_url.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik get run url failed: fail"): + trace_instance.get_project_url() + + +def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog): + trace_info = WorkflowTraceInfo( + workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de", + tenant_id="66e8e918-472e-4b69-8051-12502c34fc07", + workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"}, + workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe", + error="", + ) + + node = MagicMock() + node.id = "88e8e918-472e-4b69-8051-12502c34fc07" + node.title = "LLM Node" + node.node_type = NodeType.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + assert trace_instance.add_span.call_count >= 1 + # Verify that at least one of the spans is for the LLM Node + span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list] + assert "LLM Node" in span_names diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py 
b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py new file mode 100644 index 0000000000..870c18e53e --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py @@ -0,0 +1,583 @@ +"""Tests for the TencentTraceClient helpers that drive tracing and metrics.""" + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import Status, StatusCode + +from core.ops.tencent_trace import client as client_module +from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version +from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData + +metric_reader_instances: list[DummyMetricReader] = [] +meter_provider_instances: list[DummyMeterProvider] = [] + + +class DummyHistogram: + """Placeholder histogram type used by the stubbed metric stack.""" + + +class AggregationTemporality: + DELTA = "delta" + + +class DummyMeter: + def __init__(self) -> None: + self.created: list[tuple[dict[str, object], MagicMock]] = [] + + def create_histogram(self, **kwargs: object) -> MagicMock: + hist = MagicMock(name=f"hist-{kwargs.get('name')}") + self.created.append((kwargs, hist)) + return hist + + +class DummyMeterProvider: + def __init__(self, resource: object, metric_readers: list[object]) -> None: + self.resource = resource + self.metric_readers = metric_readers + self.meter = DummyMeter() + self.shutdown = MagicMock(name="meter_provider_shutdown") + meter_provider_instances.append(self) + + def get_meter(self, name: str, version: str) -> DummyMeter: + return self.meter + + +class DummyMetricReader: + def __init__(self, exporter: object, export_interval_millis: int) -> None: + self.exporter = exporter + self.export_interval_millis = export_interval_millis + self.shutdown = MagicMock(name="metric_reader_shutdown") + metric_reader_instances.append(self) + + +class DummyGrpcMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyHttpMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporterNoTemporality: + """Exporter that rejects preferred_temporality to exercise fallback.""" + + def __init__(self, **kwargs: object) -> None: + if "preferred_temporality" in kwargs: + raise RuntimeError("unsupported preferred_temporality") + self.kwargs = kwargs + + +def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: + """Drop fake metric modules into sys.modules so the client imports resolve.""" + + metrics_module = types.ModuleType("opentelemetry.sdk.metrics") + metrics_module.Histogram = DummyHistogram + metrics_module.MeterProvider = DummyMeterProvider + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics", metrics_module) + + metrics_export_module = types.ModuleType("opentelemetry.sdk.metrics.export") + metrics_export_module.AggregationTemporality = AggregationTemporality + metrics_export_module.PeriodicExportingMetricReader = DummyMetricReader + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics.export", metrics_export_module) + + grpc_module = types.ModuleType("opentelemetry.exporter.otlp.proto.grpc.metric_exporter") + grpc_module.OTLPMetricExporter = DummyGrpcMetricExporter + monkeypatch.setitem(sys.modules, 
"opentelemetry.exporter.otlp.proto.grpc.metric_exporter", grpc_module) + + http_module = types.ModuleType("opentelemetry.exporter.otlp.proto.http.metric_exporter") + http_module.OTLPMetricExporter = DummyHttpMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.http.metric_exporter", http_module) + + http_json_module = types.ModuleType("opentelemetry.exporter.otlp.http.json.metric_exporter") + http_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.http.json.metric_exporter", http_json_module) + + legacy_json_module = types.ModuleType("opentelemetry.exporter.otlp.json.metric_exporter") + legacy_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.json.metric_exporter", legacy_json_module) + + +@pytest.fixture(autouse=True) +def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: + metric_reader_instances.clear() + meter_provider_instances.clear() + _add_stub_modules(monkeypatch) + + +@pytest.fixture(autouse=True) +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: + span_exporter = MagicMock(name="span_exporter") + monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) + + span_processor = MagicMock(name="span_processor") + monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor)) + + tracer = MagicMock(name="tracer") + span = MagicMock(name="span") + tracer.start_span.return_value = span + + tracer_provider = MagicMock(name="tracer_provider") + tracer_provider.get_tracer.return_value = tracer + tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown") + monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider)) + + resource = MagicMock(name="resource") + monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource)) + + logger_mock = MagicMock(name="tencent_logger") + monkeypatch.setattr(client_module, "logger", logger_mock) + + trace_api_stub = SimpleNamespace( + set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"), + NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"), + ) + monkeypatch.setattr(client_module, "trace_api", trace_api_stub) + + fake_config = SimpleNamespace( + project=SimpleNamespace(version="test"), + COMMIT_SHA="sha", + DEPLOY_ENV="dev", + EDITION="cloud", + ) + monkeypatch.setattr(client_module, "dify_config", fake_config) + + monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "") + + return { + "span_exporter": span_exporter, + "span_processor": span_processor, + "tracer": tracer, + "span": span, + "tracer_provider": tracer_provider, + "logger": logger_mock, + "trace_api": trace_api_stub, + } + + +def _build_client() -> TencentTraceClient: + return TencentTraceClient( + service_name="service", + endpoint="https://trace.example.com:4317", + token="token", + ) + + +def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0") + assert _get_opentelemetry_sdk_version() == "2.0.0" + + +def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom"))) + assert 
_get_opentelemetry_sdk_version() == "1.27.0" + + +@pytest.mark.parametrize( + ("endpoint", "expected"), + [ + ( + "https://example.com:9090", + ("example.com:9090", False, "example.com", 9090), + ), + ( + "http://localhost", + ("localhost:4317", True, "localhost", 4317), + ), + ( + "example.com:bad", + ("example.com:4317", False, "example.com", 4317), + ), + ], +) +def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None: + assert TencentTraceClient._resolve_grpc_target(endpoint) == expected + + +def test_resolve_grpc_target_handles_errors() -> None: + assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})), + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + hist_mock.record.assert_called_once() + + +def test_record_methods_skip_when_histogram_missing() -> None: + client = _build_client() + client.hist_llm_duration = None + client.record_llm_duration(0.1) + + client.hist_token_usage = None + client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider") + + client.hist_time_to_first_token = None + client.record_time_to_first_token(0.2, "prov", "model") + + client.hist_time_to_generate = None + client.record_time_to_generate(0.3, "prov", "model") + + client.hist_trace_duration = None + client.record_trace_duration(0.5) + + +def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.hist_llm_duration = MagicMock(name="hist_llm_duration") + client.hist_llm_duration.record.side_effect = RuntimeError("boom") + + client.record_llm_duration(0.2) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"key": "value"}, + events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)], + status=Status(StatusCode.OK), + start_time=10, + end_time=20, + ) + + client._create_and_export_span(data) + span.set_attributes.assert_called_once() + span.add_event.assert_called_once() + span.set_status.assert_called_once() + span.end.assert_called_once_with(end_time=20) + assert client.span_contexts[2] == "ctx" + + +def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.span_contexts[10] = "existing" + span = patch_core_components["span"] + span.get_span_context.return_value = "child" + + data = SpanData( + trace_id=1, + parent_span_id=10, + span_id=11, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + + 
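# The client caches span contexts by span_id; for a child span it should wrap + # the cached parent context in a NonRecordingSpan and hand it to + # set_span_in_context, which the assertions below verify. + 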
client._create_and_export_span(data) + trace_api = patch_core_components["trace_api"] + trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.set_span_in_context.assert_called_once() + + +def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + client.tracer.start_span.side_effect = RuntimeError("boom") + + client._create_and_export_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 0 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert client.api_check() + socket_instance.connect_ex.assert_called_once() + + +def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 1 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert not client.api_check() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("localhost:4317", True, "localhost", 4317)), + ) + socket_instance.connect_ex.return_value = 1 + assert client.api_check() + + +def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None: + client = TencentTraceClient("svc", "https://localhost", "token") + + monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom"))) + assert client.api_check() + + +def test_get_project_url() -> None: + client = _build_client() + assert client.get_project_url() == "https://console.cloud.tencent.com/apm" + + +def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span_processor = patch_core_components["span_processor"] + tracer_provider = patch_core_components["tracer_provider"] + + client.shutdown() + span_processor.force_flush.assert_called_once() + span_processor.shutdown.assert_called_once() + tracer_provider.shutdown.assert_called_once() + + meter_provider = meter_provider_instances[-1] + metric_reader = metric_reader_instances[-1] + meter_provider.shutdown.assert_called_once() + metric_reader.shutdown.assert_called_once() + + +def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: + client = _build_client() + meter_provider = meter_provider_instances[-1] + meter_provider.shutdown.side_effect = RuntimeError("boom") + client.metric_reader.shutdown.side_effect = RuntimeError("boom") + + client.shutdown() + logger = patch_core_components["logger"] + logger.debug.assert_any_call( + "[Tencent APM] Error shutting down meter provider", + exc_info=True, + ) + 
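# The metric reader failure should be swallowed and logged the same way: + 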
logger.debug.assert_any_call( + "[Tencent APM] Error shutting down metric reader", + exc_info=True, + ) + + +def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err"))) + client = _build_client() + + assert client.meter is None + assert client.meter_provider is None + assert client.hist_llm_duration is None + assert client.hist_token_usage is None + assert client.hist_time_to_first_token is None + assert client.hist_time_to_generate is None + assert client.hist_trace_duration is None + assert client.metric_reader is None + + +def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: + client = _build_client() + monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) + + client.add_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData.model_construct( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"}, + events=[], + links=[], + status=Status(StatusCode.OK), + start_time=0, + end_time=1, + ) + + client._create_and_export_span(data) + (attrs,) = span.set_attributes.call_args.args + assert attrs["num"] == 5 + assert attrs["flag"] is True + assert attrs["pi"] == 3.14 + assert attrs["text"] == "value" + + +def test_record_llm_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_llm_duration") + client.hist_llm_duration = hist_mock + + client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["foo"], str) + assert attrs["bar"] == 2 + + +def test_record_trace_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_trace_duration") + client.hist_trace_duration = hist_mock + + client.record_trace_duration(1.0, {"meta": object(), "ok": True}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["meta"], str) + assert attrs["ok"] is True + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_handle_exceptions( + method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] +) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + hist_mock.record.side_effect = RuntimeError("boom") + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_metrics_initializes_grpc_metric_exporter() -> None: + client = _build_client() + metric_reader = 
metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" + assert metric_reader.exporter.kwargs["insecure"] is False + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert "preferred_temporality" in metric_reader.exporter.kwargs + + +def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"] + monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) + _ = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) + assert "preferred_temporality" not in metric_reader.exporter.kwargs + + +def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + + def _fail_import(mod_path: str) -> types.ModuleType: + raise ModuleNotFoundError(mod_path) + + monkeypatch.setattr(client_module.importlib, "import_module", _fail_import) + + _ = _build_client() + metric_reader = metric_reader_instances[-1] + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py new file mode 100644 index 0000000000..a0b6d52720 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -0,0 +1,359 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.tencent_trace.entities.semconv import ( + GEN_AI_IS_ENTRY, + GEN_AI_IS_STREAMING_REQUEST, + GEN_AI_MODEL_NAME, + GEN_AI_SPAN_KIND, + GEN_AI_USAGE_INPUT_TOKENS, + INPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) 
+from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus + + +class TestTencentSpanBuilder: + def test_get_time_nanoseconds(self): + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + mock_convert.return_value = 123456789 + dt = datetime.now() + result = TencentSpanBuilder._get_time_nanoseconds(dt) + assert result == 123456789 + mock_convert.assert_called_once_with(dt) + + def test_build_workflow_spans(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {"sys.query": "hello"} + trace_info.workflow_run_outputs = {"answer": "world"} + trace_info.metadata = {"conversation_id": "conv_id"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 2 + assert spans[0].name == "message" + assert spans[0].span_id == 2 + assert spans[1].name == "workflow" + assert spans[1].span_id == 1 + assert spans[1].parent_span_id == 2 + + def test_build_workflow_spans_no_message(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.metadata = {} # No conversation_id + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 1 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 1 + assert spans[0].name == "workflow" + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description == "some error" + assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true" + + def test_build_workflow_llm_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = { + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5}, + "prompts": ["hello"], + } + node_execution.outputs = {"text": "world"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.name == "GENERATION" + assert span.attributes[GEN_AI_MODEL_NAME] == 
"gpt-4" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10" + + def test_build_workflow_llm_span_usage_in_outputs(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = {} + node_execution.outputs = { + "text": "world", + "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, + } + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15" + assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes + + def test_build_message_span_standalone(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = {"q": "hi"} + trace_info.outputs = "hello" + trace_info.metadata = {"conversation_id": "conv_id"} + trace_info.is_streaming_request = True + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.name == "message" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[INPUT_VALUE] == str(trace_info.inputs) + + def test_build_message_span_standalone_with_error(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = None + trace_info.outputs = None + trace_info.metadata = {} + trace_info.is_streaming_request = False + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "some error" + assert span.attributes[INPUT_VALUE] == "" + + def test_build_tool_span(self): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg_id" + trace_info.tool_name = "search" + trace_info.error = "tool error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.tool_parameters = {"p": 1} + trace_info.tool_inputs = {"i": 2} + trace_info.tool_outputs = "result" + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 101 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) + + assert span.name == "search" + assert span.status.status_code == 
+        assert span.attributes[TOOL_NAME] == "search"
+
+    def test_build_retrieval_span(self):
+        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
+        trace_info.message_id = "msg_id"
+        trace_info.inputs = "query"
+        trace_info.error = None
+        trace_info.start_time = datetime.now()
+        trace_info.end_time = datetime.now()
+
+        doc = Document(
+            page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9}
+        )
+        trace_info.documents = [doc]
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 202
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
+
+        assert span.name == "retrieval"
+        assert span.attributes[RETRIEVAL_QUERY] == "query"
+        assert "content" in span.attributes[RETRIEVAL_DOCUMENT]
+
+    def test_build_retrieval_span_with_error(self):
+        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
+        trace_info.message_id = "msg_id"
+        trace_info.inputs = ""
+        trace_info.error = "retrieval failed"
+        trace_info.start_time = datetime.now()
+        trace_info.end_time = datetime.now()
+        trace_info.documents = []
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 202
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
+
+        assert span.status.status_code == StatusCode.ERROR
+        assert span.status.description == "retrieval failed"
+
+    def test_get_workflow_node_status(self):
+        node = MagicMock(spec=WorkflowNodeExecution)
+
+        node.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK
+
+        node.status = WorkflowNodeExecutionStatus.FAILED
+        node.error = "fail"
+        status = TencentSpanBuilder._get_workflow_node_status(node)
+        assert status.status_code == StatusCode.ERROR
+        assert status.description == "fail"
+
+        node.status = WorkflowNodeExecutionStatus.EXCEPTION
+        node.error = "exc"
+        status = TencentSpanBuilder._get_workflow_node_status(node)
+        assert status.status_code == StatusCode.ERROR
+        assert status.description == "exc"
+
+        node.status = WorkflowNodeExecutionStatus.RUNNING
+        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET
+
+    def test_build_workflow_retrieval_span(self):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {"conversation_id": "conv_id"}
+
+        node_execution = MagicMock(spec=WorkflowNodeExecution)
+        node_execution.id = "node_id"
+        node_execution.title = "my retrieval"
+        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        node_execution.inputs = {"query": "q1"}
+        node_execution.outputs = {"result": [{"content": "c1"}]}
+        node_execution.created_at = datetime.now()
+        node_execution.finished_at = datetime.now()
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 303
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
+
+        assert span.name == "my retrieval"
+        assert span.attributes[RETRIEVAL_QUERY] == "q1"
+        assert "c1" in span.attributes[RETRIEVAL_DOCUMENT]
+
+    def test_build_workflow_retrieval_span_empty(self):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {}
+
+        node_execution = MagicMock(spec=WorkflowNodeExecution)
+        node_execution.id = "node_id"
+        node_execution.title = "my retrieval"
+        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        node_execution.inputs = {}
+        node_execution.outputs = {}
+        node_execution.created_at = datetime.now()
+        node_execution.finished_at = datetime.now()
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 303
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
+
+        assert span.attributes[RETRIEVAL_QUERY] == ""
+        assert span.attributes[RETRIEVAL_DOCUMENT] == ""
+
+    def test_build_workflow_tool_span(self):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+
+        node_execution = MagicMock(spec=WorkflowNodeExecution)
+        node_execution.id = "node_id"
+        node_execution.title = "my tool"
+        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}}
+        node_execution.inputs = {"param": "val"}
+        node_execution.outputs = {"res": "ok"}
+        node_execution.created_at = datetime.now()
+        node_execution.finished_at = datetime.now()
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 404
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
+
+        assert span.name == "my tool"
+        assert span.attributes[TOOL_NAME] == "my tool"
+        assert "some" in span.attributes[TOOL_DESCRIPTION]
+
+    def test_build_workflow_tool_span_no_metadata(self):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+
+        node_execution = MagicMock(spec=WorkflowNodeExecution)
+        node_execution.id = "node_id"
+        node_execution.title = "my tool"
+        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        node_execution.metadata = None
+        node_execution.inputs = None
+        node_execution.outputs = {"res": "ok"}
+        node_execution.created_at = datetime.now()
+        node_execution.finished_at = datetime.now()
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 404
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
+
+        assert span.attributes[TOOL_DESCRIPTION] == "{}"
+        assert span.attributes[TOOL_PARAMETERS] == "{}"
+
+    def test_build_workflow_task_span(self):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {"conversation_id": "conv_id"}
+
+        node_execution = MagicMock(spec=WorkflowNodeExecution)
+        node_execution.id = "node_id"
+        node_execution.title = "my task"
+        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
+        node_execution.inputs = {"in": 1}
+        node_execution.outputs = {"out": 2}
+        node_execution.created_at = datetime.now()
+        node_execution.finished_at = datetime.now()
+
+        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+            mock_convert_id.return_value = 505
+            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
+                span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
+
+        assert span.name == "my task"
+        assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
new file mode 100644
index 0000000000..077a92d866
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
@@ -0,0 +1,647 @@
+import logging
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.ops.entities.config_entity import TencentConfig
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
+from core.ops.tencent_trace.tencent_trace import TencentDataTrace
+from dify_graph.entities import WorkflowNodeExecution
+from dify_graph.enums import NodeType
+from models import Account, App, TenantAccountJoin
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def tencent_config():
+    return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token")
+
+
+@pytest.fixture
+def mock_trace_client():
+    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_span_builder():
+    with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_trace_utils():
+    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
+        yield mock
+
+
+@pytest.fixture
+def tencent_data_trace(tencent_config, mock_trace_client):
+    return TencentDataTrace(tencent_config)
+
+
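+# These fixtures patch TencentTraceClient, TencentSpanBuilder, and TencentTraceUtils at their import
+# site, so the tests below observe calls on mocks instead of talking to a live APM backend.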
+class TestTencentDataTrace:
+    def test_init(self, tencent_config, mock_trace_client):
+        trace = TencentDataTrace(tencent_config)
+        mock_trace_client.assert_called_once_with(
+            service_name=tencent_config.service_name,
+            endpoint=tencent_config.endpoint,
+            token=tencent_config.token,
+            metrics_export_interval_sec=5,
+        )
+        assert trace.trace_client == mock_trace_client.return_value
+
+    def test_trace_dispatch(self, tencent_data_trace):
+        methods = [
+            (
+                WorkflowTraceInfo(
+                    workflow_id="wf",
+                    tenant_id="t",
+                    workflow_run_id="run",
+                    workflow_run_elapsed_time=1.0,
+                    workflow_run_status="s",
+                    workflow_run_inputs={},
+                    workflow_run_outputs={},
+                    workflow_run_version="v",
+                    total_tokens=0,
+                    file_list=[],
+                    query="",
+                    metadata={},
+                ),
+                "workflow_trace",
+            ),
+            (
+                MessageTraceInfo(
+                    message_id="msg",
+                    message_data={},
+                    inputs={},
+                    outputs={},
+                    start_time=None,
+                    end_time=None,
+                    conversation_mode="chat",
+                    conversation_model="gpt-3.5-turbo",
+                    message_tokens=0,
+                    answer_tokens=0,
+                    total_tokens=0,
+                    metadata={},
+                ),
+                "message_trace",
+            ),
+            (
+                ModerationTraceInfo(
+                    flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m"
+                ),
+                None,
+            ),  # no dedicated handler: trace() is a no-op for this type
+            (
+                SuggestedQuestionTraceInfo(
+                    suggested_question=[],
+                    level="l",
+                    total_tokens=0,
+                    metadata={},
+                    message_id="m",
+                    message_data={},
+                    inputs={},
+                    start_time=None,
+                    end_time=None,
+                ),
+                "suggested_question_trace",
+            ),
+            (
+                DatasetRetrievalTraceInfo(
+                    metadata={},
+                    message_id="m",
+                    message_data={},
+                    inputs={},
+                    documents=[],
+                    start_time=None,
+                    end_time=None,
+                ),
+                "dataset_retrieval_trace",
+            ),
+            (
+                ToolTraceInfo(
+                    tool_name="t",
+                    tool_inputs={},
+                    tool_outputs="",
+                    tool_config={},
+                    tool_parameters={},
+                    time_cost=0,
+                    metadata={},
+                    message_id="m",
+                    inputs={},
+                    outputs={},
+                    start_time=None,
+                    end_time=None,
+                ),
+                "tool_trace",
+            ),
+            (
+                GenerateNameTraceInfo(
+                    tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None
+                ),
+                None,
+            ),  # no dedicated handler: trace() is a no-op for this type
+        ]
+
+        for trace_info, method_name in methods:
+            if method_name:
+                with patch.object(tencent_data_trace, method_name) as mock_method:
+                    tencent_data_trace.trace(trace_info)
+                    mock_method.assert_called_once_with(trace_info)
+            else:
+                tencent_data_trace.trace(trace_info)
+
+    def test_api_check(self, tencent_data_trace):
+        tencent_data_trace.trace_client.api_check.return_value = True
+        assert tencent_data_trace.api_check() is True
+        tencent_data_trace.trace_client.api_check.assert_called_once()
+
+    def test_get_project_url(self, tencent_data_trace):
+        tencent_data_trace.trace_client.get_project_url.return_value = "http://url"
+        assert tencent_data_trace.get_project_url() == "http://url"
+        tencent_data_trace.trace_client.get_project_url.assert_called_once()
+
+    def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.workflow_run_id = "run-id"
+        trace_info.trace_id = "parent-trace-id"
+
+        mock_trace_utils.convert_to_trace_id.return_value = 123
+        mock_trace_utils.create_link.return_value = "link"
+
+        with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
+            with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
+                with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
+                    mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
+
+                    tencent_data_trace.workflow_trace(trace_info)
+
+                    mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
+                    mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
+                    mock_span_builder.build_workflow_spans.assert_called_once()
+                    assert tencent_data_trace.trace_client.add_span.call_count == 2
+                    mock_proc.assert_called_once_with(trace_info, 123)
+                    mock_dur.assert_called_once_with(trace_info)
+
+    def test_workflow_trace_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.workflow_run_id = "run-id"
+
+        with patch(
+            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+        ):
+            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                tencent_data_trace.workflow_trace(trace_info)
+                mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
+
+    def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.message_id = "msg-id"
+        trace_info.trace_id = "parent-trace-id"
+
+        mock_trace_utils.convert_to_trace_id.return_value = 123
+        mock_trace_utils.create_link.return_value = "link"
+
+        with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
+            with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
+                with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
+                    mock_span_builder.build_message_span.return_value = MagicMock()
+
+                    tencent_data_trace.message_trace(trace_info)
+
+                    mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
+                    mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
+                    mock_span_builder.build_message_span.assert_called_once()
+                    tencent_data_trace.trace_client.add_span.assert_called_once()
+                    mock_metrics.assert_called_once_with(trace_info)
+                    mock_dur.assert_called_once_with(trace_info)
+
+    def test_message_trace_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+
+        with patch(
+            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+        ):
+            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                tencent_data_trace.message_trace(trace_info)
+                mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
+
+    def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
+        trace_info = MagicMock(spec=ToolTraceInfo)
+        trace_info.message_id = "msg-id"
+
+        mock_trace_utils.convert_to_span_id.return_value = 456
+        mock_trace_utils.convert_to_trace_id.return_value = 123
+
+        tencent_data_trace.tool_trace(trace_info)
+
+        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
+        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
+        mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456)
+        tencent_data_trace.trace_client.add_span.assert_called_once()
+
+    def test_tool_trace_no_msg_id(self, tencent_data_trace):
+        trace_info = MagicMock(spec=ToolTraceInfo)
+        trace_info.message_id = None
+
+        tencent_data_trace.tool_trace(trace_info)
+        tencent_data_trace.trace_client.add_span.assert_not_called()
+
+    def test_tool_trace_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=ToolTraceInfo)
+        trace_info.message_id = "msg-id"
+
+        with patch(
+            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+        ):
+            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                tencent_data_trace.tool_trace(trace_info)
+                mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
+
+    def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
+        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
+        trace_info.message_id = "msg-id"
+
+        mock_trace_utils.convert_to_span_id.return_value = 456
+        mock_trace_utils.convert_to_trace_id.return_value = 123
+
+        tencent_data_trace.dataset_retrieval_trace(trace_info)
+
+        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
+        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
+        mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456)
+        tencent_data_trace.trace_client.add_span.assert_called_once()
+
+    def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace):
+        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
+        trace_info.message_id = None
+
+        tencent_data_trace.dataset_retrieval_trace(trace_info)
+        tencent_data_trace.trace_client.add_span.assert_not_called()
+
+    def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
+        trace_info.message_id = "msg-id"
+
+        with patch(
+            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+        ):
+            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                tencent_data_trace.dataset_retrieval_trace(trace_info)
+                mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
+
+    def test_suggested_question_trace(self, tencent_data_trace):
+        trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
+        with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
+            tencent_data_trace.suggested_question_trace(trace_info)
+            mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
+
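+    # Like the test above, this exercises suggested_question_trace purely through its log output.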
patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") + + def test_suggested_question_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") + + def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + mock_trace_utils.convert_to_span_id.return_value = 111 + + node1 = MagicMock(spec=WorkflowNodeExecution) + node1.id = "n1" + node1.node_type = NodeType.LLM + node2 = MagicMock(spec=WorkflowNodeExecution) + node2.id = "n2" + node2.node_type = NodeType.TOOL + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): + with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_metrics.assert_called_once_with(node1) + + def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.return_value = 111 + + node = MagicMock(spec=WorkflowNodeExecution) + node.id = "n1" + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + # The exception should be caught by the outer handler since convert_to_span_id is called first + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + nodes = [ + (NodeType.LLM, mock_span_builder.build_workflow_llm_span), + (NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span), + (NodeType.TOOL, mock_span_builder.build_workflow_tool_span), + (NodeType.CODE, mock_span_builder.build_workflow_task_span), + ] + + for node_type, builder_method in nodes: + node = MagicMock(spec=WorkflowNodeExecution) + node.node_type = node_type + builder_method.return_value = "span" + + result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456) + + assert result == "span" + 
+            builder_method.assert_called_once_with(123, 456, trace_info, node)
+
+    def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
+        node = MagicMock(spec=WorkflowNodeExecution)
+        node.node_type = NodeType.LLM
+        node.id = "n1"
+        mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+            result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
+            assert result is None
+            mock_log.assert_called_once()
+
+    def test_get_workflow_node_executions(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {"app_id": "app-1"}
+        trace_info.workflow_run_id = "run-1"
+
+        app = MagicMock(spec=App)
+        app.id = "app-1"
+        app.created_by = "user-1"
+
+        account = MagicMock(spec=Account)
+        account.id = "user-1"
+
+        tenant_join = MagicMock(spec=TenantAccountJoin)
+        tenant_join.tenant_id = "tenant-1"
+
+        mock_executions = [MagicMock()]
+
+        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+            mock_db.engine = "engine"
+            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+                session = mock_session_ctx.return_value.__enter__.return_value
+                session.scalar.side_effect = [app, account]
+                session.query.return_value.filter_by.return_value.first.return_value = tenant_join
+
+                with patch(
+                    "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
+                ) as mock_repo:
+                    mock_repo.return_value.get_by_workflow_run.return_value = mock_executions
+
+                    results = tencent_data_trace._get_workflow_node_executions(trace_info)
+
+                    assert results == mock_executions
+                    account.set_tenant_id.assert_called_once_with("tenant-1")
+
+    def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {}
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            results = tencent_data_trace._get_workflow_node_executions(trace_info)
+            assert results == []
+            mock_log.assert_called_once()
+
+    def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.metadata = {"app_id": "app-1"}
+
+        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+            mock_db.init_app = MagicMock()  # Ensure init_app is mocked
+            mock_db.engine = "engine"
+            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+                session = mock_session_ctx.return_value.__enter__.return_value
+                session.scalar.return_value = None
+
+                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                    results = tencent_data_trace._get_workflow_node_executions(trace_info)
+                    assert results == []
+                    mock_log.assert_called_once()
+
+    def test_get_user_id_workflow(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.tenant_id = "tenant-1"
+        trace_info.metadata = {"user_id": "user-1"}
+
+        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
+            with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+                mock_db.init_app = MagicMock()
+                mock_db.engine = MagicMock()
+
+                user_id = tencent_data_trace._get_user_id(trace_info)
+                assert user_id == "unknown"
+
+    def test_get_user_id_only_user_id(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.metadata = {"user_id": "user-1"}
+
+        user_id = tencent_data_trace._get_user_id(trace_info)
+        assert user_id == "user-1"
+
+    def test_get_user_id_anonymous(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.metadata = {}
+
+        user_id = tencent_data_trace._get_user_id(trace_info)
+        assert user_id == "anonymous"
+
+    def test_get_user_id_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.tenant_id = "t"
+        trace_info.metadata = {"user_id": "u"}
+
+        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
+            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+                user_id = tencent_data_trace._get_user_id(trace_info)
+                assert user_id == "unknown"
+                mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
+
+    def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
+        node = MagicMock(spec=WorkflowNodeExecution)
+        node.process_data = {
+            "usage": {
+                "latency": 2.5,
+                "time_to_first_token": 0.5,
+                "time_to_generate": 2.0,
+                "prompt_tokens": 10,
+                "completion_tokens": 20,
+            },
+            "model_provider": "openai",
+            "model_name": "gpt-4",
+            "model_mode": "chat",
+        }
+        node.outputs = {}
+
+        tencent_data_trace._record_llm_metrics(node)
+
+        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
+        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
+        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
+        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
+
+    def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace):
+        node = MagicMock(spec=WorkflowNodeExecution)
+        node.process_data = {}
+        node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}}
+
+        tencent_data_trace._record_llm_metrics(node)
+        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
+        tencent_data_trace.trace_client.record_token_usage.assert_called_once()
+
+    def test_record_llm_metrics_exception(self, tencent_data_trace):
+        node = MagicMock(spec=WorkflowNodeExecution)
+        node.process_data = None
+        node.outputs = None
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.debug"):
+            tencent_data_trace._record_llm_metrics(node)  # must not raise
+
+    def test_record_message_llm_metrics(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"}
+        trace_info.message_data = {"provider_response_latency": 1.1}
+        trace_info.is_streaming_request = True
+        trace_info.gen_ai_server_time_to_first_token = 0.2
+        trace_info.llm_streaming_time_to_generate = 0.9
+        trace_info.message_tokens = 15
+        trace_info.answer_tokens = 25
+
+        tencent_data_trace._record_message_llm_metrics(trace_info)
+
+        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
+        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
+        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
+        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
+
+    def test_record_message_llm_metrics_object_data(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.metadata = {}
+        msg_data = MagicMock()
+        msg_data.provider_response_latency = 1.1
+        msg_data.model_provider = "anthropic"
+        msg_data.model_id = "claude"
+        trace_info.message_data = msg_data
+        trace_info.is_streaming_request = False
+
+        tencent_data_trace._record_message_llm_metrics(trace_info)
+        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
+
+    def test_record_message_llm_metrics_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.metadata = None
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.debug"):
+            tencent_data_trace._record_message_llm_metrics(trace_info)  # must not raise
+
+    def test_record_workflow_trace_duration(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        from datetime import datetime, timedelta
+
+        now = datetime.now()
+        trace_info.start_time = now
+        trace_info.end_time = now + timedelta(seconds=3)
+        trace_info.workflow_run_status = "succeeded"
+        trace_info.conversation_id = "conv-1"
+
+        # Mock record_trace_duration to capture its arguments
+        with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
+            tencent_data_trace._record_workflow_trace_duration(trace_info)
+
+            mock_record.assert_called_once()
+            args, kwargs = mock_record.call_args
+
+            # Validate the duration argument
+            assert args[0] == 3.0
+
+            # Validate the attributes dict (passed either as a keyword or as the second positional argument)
+            attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
+            assert attributes["conversation_mode"] == "workflow"
+            assert attributes["has_conversation"] == "true"
+
+    def test_record_workflow_trace_duration_fallback(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.start_time = None
+        trace_info.workflow_run_elapsed_time = 4.5
+        trace_info.workflow_run_status = "failed"
+        trace_info.conversation_id = None
+
+        with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
+            tencent_data_trace._record_workflow_trace_duration(trace_info)
+            mock_record.assert_called_once()
+            args, kwargs = mock_record.call_args
+            assert args[0] == 4.5
+            # Check the attributes dict (either in kwargs or as the second positional argument)
+            attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
+            assert attributes["has_conversation"] == "false"
+
+    def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=WorkflowTraceInfo)
+        trace_info.start_time = MagicMock()  # a bare MagicMock start_time exercises the exception path
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.debug"):
+            tencent_data_trace._record_workflow_trace_duration(trace_info)  # must not raise
+
+    def test_record_message_trace_duration(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        from datetime import datetime, timedelta
+
+        now = datetime.now()
+        trace_info.start_time = now
+        trace_info.end_time = now + timedelta(seconds=2)
+        trace_info.conversation_mode = "chat"
+        trace_info.is_streaming_request = True
+
+        tencent_data_trace._record_message_trace_duration(trace_info)
+        tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with(
+            2.0, {"conversation_mode": "chat", "stream": "true"}
+        )
+
+    def test_record_message_trace_duration_exception(self, tencent_data_trace):
+        trace_info = MagicMock(spec=MessageTraceInfo)
+        trace_info.start_time = None
+
+        with patch("core.ops.tencent_trace.tencent_trace.logger.debug"):
+            tencent_data_trace._record_message_trace_duration(trace_info)  # must not raise
+
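+    # __del__ must shut the trace client down exactly once, even if shutdown raises.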
+    def test_del(self, tencent_data_trace):
+        client = tencent_data_trace.trace_client
+        tencent_data_trace.__del__()
+        client.shutdown.assert_called_once()
+
+    def test_del_exception(self, tencent_data_trace):
+        tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
+        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+            tencent_data_trace.__del__()
+            mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
new file mode 100644
index 0000000000..ef28d18e20
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
@@ -0,0 +1,106 @@
+"""Unit tests for Tencent APM tracing utilities."""
+
+from __future__ import annotations
+
+import hashlib
+import uuid
+from datetime import UTC, datetime
+from unittest.mock import patch
+
+import pytest
+from opentelemetry.trace import Link, TraceFlags
+
+from core.ops.tencent_trace.utils import TencentTraceUtils
+
+
+def test_convert_to_trace_id_with_valid_uuid() -> None:
+    uuid_str = "12345678-1234-5678-1234-567812345678"
+    assert TencentTraceUtils.convert_to_trace_id(uuid_str) == uuid.UUID(uuid_str).int
+
+
+def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
+    expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
+    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+        assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
+        uuid4_mock.assert_called_once()
+
+
+def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None:
+    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
+        TencentTraceUtils.convert_to_trace_id("not-a-uuid")
+
+
+def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
+    uuid_str = "12345678-1234-5678-1234-567812345678"
+    span_type = "llm"
+
+    uuid_obj = uuid.UUID(uuid_str)
+    combined_key = f"{uuid_obj.hex}-{span_type}"
+    hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
+    expected = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
+
+    assert TencentTraceUtils.convert_to_span_id(uuid_str, span_type) == expected
+    assert TencentTraceUtils.convert_to_span_id(uuid_str, "other") != expected
+
+
+def test_convert_to_span_id_uses_uuid4_when_none() -> None:
+    expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
+    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+        span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
+        assert isinstance(span_id, int)
+        uuid4_mock.assert_called_once()
+
+
+def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
+    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
+        TencentTraceUtils.convert_to_span_id("bad-uuid", "span")
+
+
+def test_generate_span_id_skips_invalid_span_id() -> None:
+    with patch(
+        "core.ops.tencent_trace.utils.random.getrandbits",
+        side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
+    ) as bits_mock:
+        assert TencentTraceUtils.generate_span_id() == 42
+        assert bits_mock.call_count == 2
+
+
+def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None:
+    start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
+    expected = int(start_time.timestamp() * 1e9)
+    assert TencentTraceUtils.convert_datetime_to_nanoseconds(start_time) == expected
+
+
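+# Patching the module-level datetime pins "now" to a fixed instant for the assertion below.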
+def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
+    fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
+    expected = int(fixed.timestamp() * 1e9)
+
+    with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
+        datetime_mock.now.return_value = fixed
+        assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
+        datetime_mock.now.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    ("trace_id_str", "expected_trace_id"),
+    [
+        ("0" * 31 + "1", int("0" * 31 + "1", 16)),
+        (str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int),
+    ],
+)
+def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None:
+    link = TencentTraceUtils.create_link(trace_id_str)
+    assert isinstance(link, Link)
+    assert link.context.trace_id == expected_trace_id
+    assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID
+    assert link.context.is_remote is False
+    assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED)
+
+
+@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
+def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
+    fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
+    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
+        link = TencentTraceUtils.create_link(trace_id_str)  # type: ignore[arg-type]
+        assert link.context.trace_id == fallback_uuid.int
+        uuid4_mock.assert_called_once()
diff --git a/api/tests/unit_tests/core/ops/test_base_trace_instance.py b/api/tests/unit_tests/core/ops/test_base_trace_instance.py
new file mode 100644
index 0000000000..a8bee7dfa7
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/test_base_trace_instance.py
@@ -0,0 +1,112 @@
+from unittest.mock import MagicMock
+
+import pytest
+from sqlalchemy.orm import Session
+
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.entities.trace_entity import BaseTraceInfo
+from models import Account, App, TenantAccountJoin
+
+
+class ConcreteTraceInstance(BaseTraceInstance):
+    def __init__(self, trace_config: BaseTracingConfig):
+        super().__init__(trace_config)
+
+    def trace(self, trace_info: BaseTraceInfo):
+        super().trace(trace_info)
+
+
+@pytest.fixture
+def mock_db_session(monkeypatch):
+    mock_session = MagicMock(spec=Session)
+    mock_session.__enter__.return_value = mock_session
+    mock_session.__exit__.return_value = None
+
+    mock_session_class = MagicMock(return_value=mock_session)
+
+    monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class)
+    monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock())
+    return mock_session
+
+
+def test_get_service_account_with_tenant_app_not_found(mock_db_session):
+    mock_db_session.scalar.return_value = None
+
+    config = MagicMock(spec=BaseTracingConfig)
+    instance = ConcreteTraceInstance(config)
+
+    with pytest.raises(ValueError, match="App with id some_app_id not found"):
+        instance.get_service_account_with_tenant("some_app_id")
+
+
+def test_get_service_account_with_tenant_no_creator(mock_db_session):
+    mock_app = MagicMock(spec=App)
+    mock_app.id = "some_app_id"
+    mock_app.created_by = None
+    mock_db_session.scalar.return_value = mock_app
+
+    config = MagicMock(spec=BaseTracingConfig)
+    instance = ConcreteTraceInstance(config)
+
+    with pytest.raises(ValueError, match="App with id some_app_id has no creator"):
+        instance.get_service_account_with_tenant("some_app_id")
+
+
+def test_get_service_account_with_tenant_creator_not_found(mock_db_session):
+    mock_app = MagicMock(spec=App)
+    mock_app.id = "some_app_id"
+    mock_app.created_by = "creator_id"
+
+    # First call to scalar returns app, second returns None (for account)
+    mock_db_session.scalar.side_effect = [mock_app, None]
+
+    config = MagicMock(spec=BaseTracingConfig)
+    instance = ConcreteTraceInstance(config)
+
+    with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"):
+        instance.get_service_account_with_tenant("some_app_id")
+
+
+def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
+    mock_app = MagicMock(spec=App)
+    mock_app.id = "some_app_id"
+    mock_app.created_by = "creator_id"
+
+    mock_account = MagicMock(spec=Account)
+    mock_account.id = "creator_id"
+
+    mock_db_session.scalar.side_effect = [mock_app, mock_account]
+
+    # session.query(TenantAccountJoin).filter_by(...).first() returns None
+    mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
+
+    config = MagicMock(spec=BaseTracingConfig)
+    instance = ConcreteTraceInstance(config)
+
+    with pytest.raises(ValueError, match="Current tenant not found for account creator_id"):
+        instance.get_service_account_with_tenant("some_app_id")
+
+
+def test_get_service_account_with_tenant_success(mock_db_session):
+    mock_app = MagicMock(spec=App)
+    mock_app.id = "some_app_id"
+    mock_app.created_by = "creator_id"
+
+    mock_account = MagicMock(spec=Account)
+    mock_account.id = "creator_id"
+    mock_account.set_tenant_id = MagicMock()
+
+    mock_db_session.scalar.side_effect = [mock_app, mock_account]
+
+    mock_tenant_join = MagicMock(spec=TenantAccountJoin)
+    mock_tenant_join.tenant_id = "tenant_id"
+    mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
+
+    config = MagicMock(spec=BaseTracingConfig)
+    instance = ConcreteTraceInstance(config)
+
+    result = instance.get_service_account_with_tenant("some_app_id")
+
+    assert result == mock_account
+    mock_account.set_tenant_id.assert_called_once_with("tenant_id")
diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py
new file mode 100644
index 0000000000..2d325ccb0e
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py
@@ -0,0 +1,576 @@
+import contextlib
+import json
+import queue
+from datetime import datetime, timedelta
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.ops.ops_trace_manager import (
+    OpsTraceManager,
+    TraceQueueManager,
+    TraceTask,
+    TraceTaskName,
+)
+
+
+class DummyConfig:
+    def __init__(self, **kwargs):
+        self._data = kwargs
+
+    def model_dump(self):
+        return dict(self._data)
+
+
+class DummyTraceInstance:
+    instances: list["DummyTraceInstance"] = []
+
+    def __init__(self, config):
+        self.config = config
+        DummyTraceInstance.instances.append(self)
+
+    def api_check(self):
+        return True
+
+    def get_project_key(self):
+        return "fake-key"
+
+    def get_project_url(self):
+        return "https://project.fake"
+
+
+FAKE_PROVIDER_ENTRY = {
+    "config_class": DummyConfig,
+    "secret_keys": ["secret_value"],
+    "other_keys": ["other_value"],
+    "trace_instance": DummyTraceInstance,
+}
+
+
+class FakeProviderMap:
+    def __init__(self, data):
+        self._data = data
+
+    def __getitem__(self, key):
+        if key in self._data:
+            return self._data[key]
+        raise KeyError(f"Unsupported tracing provider: {key}")
+
+
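+# A stand-in for threading.Timer: records that it was started but never schedules anything.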
provider: {key}") + + +class DummyTimer: + def __init__(self, interval, function): + self.interval = interval + self.function = function + self.name = "" + self.daemon = False + self.started = False + + def start(self): + self.started = True + + def is_alive(self): + return False + + +class FakeMessageFile: + def __init__(self): + self.url = "path/to/file" + self.id = "file-id" + self.type = "document" + self.created_by_role = "role" + self.created_by = "user" + + +def make_message_data(**overrides): + created_at = datetime(2025, 2, 20, 12, 0, 0) + base = { + "id": "msg-id", + "conversation_id": "conv-id", + "created_at": created_at, + "updated_at": created_at + timedelta(seconds=3), + "message": "hello", + "provider_response_latency": 1, + "message_tokens": 5, + "answer_tokens": 7, + "answer": "world", + "error": "", + "status": "complete", + "model_provider": "provider", + "model_id": "model", + "from_end_user_id": "end-user", + "from_account_id": "account", + "agent_based": False, + "workflow_run_id": "workflow-run", + "from_source": "source", + "message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}), + "agent_thoughts": [], + "query": "sample-query", + "inputs": "sample-input", + } + base.update(overrides) + + class MessageData: + def __init__(self, data): + self.__dict__.update(data) + + def to_dict(self): + return dict(self.__dict__) + + return MessageData(base) + + +def make_agent_thought(tool_name, created_at): + return SimpleNamespace( + tools=[tool_name], + created_at=created_at, + tool_meta={ + tool_name: { + "tool_config": {"foo": "bar"}, + "time_cost": 5, + "error": "", + "tool_parameters": {"x": 1}, + } + }, + ) + + +def make_workflow_run(): + return SimpleNamespace( + workflow_id="wf-1", + tenant_id="tenant", + id="run-id", + elapsed_time=10, + status="finished", + inputs_dict={"sys.file": ["f1"], "query": "search"}, + outputs_dict={"out": "value"}, + version="3", + error=None, + total_tokens=12, + workflow_run_id="run-id", + created_at=datetime(2025, 2, 20, 10, 0, 0), + finished_at=datetime(2025, 2, 20, 10, 0, 5), + triggered_from="user", + app_id="app-id", + to_dict=lambda self=None: {"run": "value"}, + ) + + +def configure_db_query(session, *, message_file=None, workflow_app_log=None): + def _side_effect(model): + query = MagicMock() + query.filter_by.return_value.first.return_value = None + if message_file and model.__name__ == "MessageFile": + query.filter_by.return_value.first.return_value = message_file + if workflow_app_log and model.__name__ == "WorkflowAppLog": + query.filter_by.return_value.first.return_value = workflow_app_log + return query + + session.query.side_effect = _side_effect + + +class DummySessionContext: + scalar_values = [] + + def __init__(self, engine): + self._values = list(self.scalar_values) + self._index = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def scalar(self, *args, **kwargs): + if self._index >= len(self._values): + return None + value = self._values[self._index] + self._index += 1 + return value + + +@pytest.fixture(autouse=True) +def patch_provider_map(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY}) + ) + OpsTraceManager.ops_trace_instances_cache.clear() + OpsTraceManager.decrypted_configs_cache.clear() + + +@pytest.fixture(autouse=True) +def patch_timer_and_current_app(monkeypatch): + 
monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue()) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None) + + class FakeApp: + def app_context(self): + return contextlib.nullcontext() + + fake_current = MagicMock() + fake_current._get_current_object.return_value = FakeApp() + monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current) + + +@pytest.fixture(autouse=True) +def patch_sqlalchemy_session(monkeypatch): + monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext) + + +@pytest.fixture +def encryption_mocks(monkeypatch): + encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}") + batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values]) + obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}") + monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock) + return encrypt_mock, batch_decrypt_mock, obfuscate_mock + + +@pytest.fixture +def mock_db(monkeypatch): + session = MagicMock() + session.scalars.return_value.all.return_value = ["chat"] + db_mock = MagicMock() + db_mock.session = session + db_mock.engine = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.db", db_mock) + return session + + +@pytest.fixture +def workflow_repo_fixture(monkeypatch): + repo = MagicMock() + repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run() + monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo)) + return repo + + +@pytest.fixture +def trace_task_message(monkeypatch, mock_db): + message_data = make_message_data() + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) + configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + return message_data + + +def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "value", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "enc-value" + assert encrypted["other_value"] == "info" + + +def test_encrypt_tracing_config_preserves_star(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "*", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "keep" + + +def test_decrypt_tracing_config_caches(encryption_mocks): + _, decrypt_mock, _ = encryption_mocks + payload = {"secret_value": "enc", "other_value": "info"} + first = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + second = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + assert first == second + assert decrypt_mock.call_count == 1 + + +def test_obfuscated_decrypt_token(encryption_mocks): + _, _, obfuscate_mock = encryption_mocks + result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"}) + assert "secret_value" in result + assert result["secret_value"] == "ob-value" + obfuscate_mock.assert_called_once() + + +def 
+def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db):
+    trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"})
+    mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
+    app = SimpleNamespace(id="app-id", tenant_id="tenant")
+    mock_db.scalar.return_value = app
+
+    decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
+    assert decrypted["other_value"] == "info"
+
+
+def test_get_decrypted_tracing_config_missing_trace_config(mock_db):
+    mock_db.query.return_value.where.return_value.first.return_value = None
+    assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None
+
+
+def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db):
+    trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"})
+    mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
+    mock_db.scalar.return_value = None
+    with pytest.raises(ValueError, match="App not found"):
+        OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
+
+
+def test_get_decrypted_tracing_config_raises_for_none_config(mock_db):
+    trace_config_data = SimpleNamespace(tracing_config=None)
+    mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
+    mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant")
+    with pytest.raises(ValueError, match="Tracing config cannot be None"):
+        OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
+
+
+def test_get_ops_trace_instance_handles_none_app(mock_db):
+    mock_db.query.return_value.where.return_value.first.return_value = None
+    assert OpsTraceManager.get_ops_trace_instance("app-id") is None
+
+
+def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch):
+    app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False}))
+    mock_db.query.return_value.where.return_value.first.return_value = app
+    assert OpsTraceManager.get_ops_trace_instance("app-id") is None
+
+
+def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch):
+    app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"}))
+    mock_db.query.return_value.where.return_value.first.return_value = app
+    monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({}))
+    assert OpsTraceManager.get_ops_trace_instance("app-id") is None
+
+
+def test_get_ops_trace_instance_success(monkeypatch, mock_db):
+    app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"}))
+    mock_db.query.return_value.where.return_value.first.return_value = app
+    monkeypatch.setattr(
+        "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config",
+        classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}),
+    )
+    instance = OpsTraceManager.get_ops_trace_instance("app-id")
+    assert instance is not None
+    cached_instance = OpsTraceManager.get_ops_trace_instance("app-id")
+    assert instance is cached_instance
+
+
+def test_get_app_config_through_message_id_returns_none(mock_db):
+    mock_db.scalar.return_value = None
+    assert OpsTraceManager.get_app_config_through_message_id("m") is None
+
+
+def test_get_app_config_through_message_id_prefers_override(mock_db):
+    message = SimpleNamespace(conversation_id="conv")
+    conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"})
+    app_config = SimpleNamespace(id="config-id")
SimpleNamespace(id="config-id") + mock_db.scalar.side_effect = [message, conversation] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result == {"foo": "bar"} + + +def test_get_app_config_through_message_id_app_model_config(mock_db): + message = SimpleNamespace(conversation_id="conv") + conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None) + mock_db.scalar.side_effect = [message, conversation, SimpleNamespace(id="cfg")] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result.id == "cfg" + + +def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="Invalid tracing provider"): + OpsTraceManager.update_app_tracing_config("app", True, "bad") + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.update_app_tracing_config("app", True, None) + + +def test_update_app_tracing_config_success(mock_db): + app = SimpleNamespace(id="app-id", tracing="{}") + mock_db.query.return_value.where.return_value.first.return_value = app + OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") + assert app.tracing is not None + mock_db.commit.assert_called_once() + + +def test_get_app_tracing_config_errors_when_missing(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.get_app_tracing_config("app") + + +def test_get_app_tracing_config_returns_defaults(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} + + +def test_get_app_tracing_config_returns_payload(mock_db): + payload = {"enabled": True, "tracing_provider": "dummy"} + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + assert OpsTraceManager.get_app_tracing_config("app-id") == payload + + +def test_check_and_project_helpers(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", + FakeProviderMap( + { + "dummy": { + "config_class": DummyConfig, + "trace_instance": type( + "Trace", + (), + { + "__init__": lambda self, cfg: None, + "api_check": lambda self: True, + "get_project_key": lambda self: "key", + "get_project_url": lambda self: "url", + }, + ), + "secret_keys": [], + "other_keys": [], + } + } + ), + ) + assert OpsTraceManager.check_trace_config_is_effective({}, "dummy") + assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key" + assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url" + + +def test_trace_task_conversation_and_extract(monkeypatch): + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg") + assert task.conversation_trace(foo="bar") == {"foo": "bar"} + assert task._extract_streaming_metrics(make_message_data(message_metadata="not json")) == {} + + +def test_trace_task_message_trace(trace_task_message, mock_db): + task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id") + result = task.message_trace("msg-id") + assert result.message_id == "msg-id" + + +def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): + DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] + execution = SimpleNamespace(id_="run-id") + task = TraceTask( + 
+        trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
+    )
+    result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user")
+    assert result.workflow_run_id == "run-id"
+    assert result.workflow_id == "wf-1"
+
+
+def test_trace_task_moderation_trace(trace_task_message):
+    task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id")
+    moderation_result = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True)
+    timer = {"start": 1, "end": 2}
+    result = task.moderation_trace("msg-id", timer, moderation_result=moderation_result, inputs={"src": "payload"})
+    assert result.flagged is True
+    assert result.message_id == "log-id"
+
+
+def test_trace_task_suggested_question_trace(trace_task_message):
+    task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id")
+    timer = {"start": 1, "end": 2}
+    result = task.suggested_question_trace("msg-id", timer, suggested_question=["q1"])
+    assert result.message_id == "log-id"
+    assert "suggested_question" in result.__dict__
+
+
+def test_trace_task_dataset_retrieval_trace(trace_task_message):
+    task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id")
+    timer = {"start": 1, "end": 2}
+    mock_doc = SimpleNamespace(model_dump=lambda: {"doc": "value"})
+    result = task.dataset_retrieval_trace("msg-id", timer, documents=[mock_doc])
+    assert result.documents == [{"doc": "value"}]
+
+
+def test_trace_task_tool_trace(monkeypatch, mock_db):
+    custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))])
+    monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message)
+    configure_db_query(mock_db, message_file=FakeMessageFile())
+    task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id")
+    timer = {"start": 1, "end": 5}
+    result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result")
+    assert result.tool_name == "tool-a"
+    assert result.time_cost == 5
+
+
+def test_trace_task_generate_name_trace():
+    task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id")
+    timer = {"start": 1, "end": 2}
+    assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {}
+    result = task.generate_name_trace(
+        "conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q"
+    )
+    assert result.outputs == "name"
+    assert result.tenant_id == "tenant"
+
+
+def test_extract_streaming_metrics_invalid_json():
+    task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
+    fake_message = make_message_data(message_metadata="invalid")
+    assert task._extract_streaming_metrics(fake_message) == {}
+
+
+def test_trace_queue_manager_add_and_collect(monkeypatch):
+    monkeypatch.setattr(
+        "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
+    )
+    manager = TraceQueueManager(app_id="app-id", user_id="user")
+    task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
+    manager.add_trace_task(task)
+    tasks = manager.collect_tasks()
+    assert tasks == [task]
+
+
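+# run() should drain collect_tasks() and hand the batch to send_to_celery; both are stubbed to observe the handoff.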
TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE) + called = {} + + def fake_collect(): + return [task] + + def fake_send(tasks): + called["tasks"] = tasks + + monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: fake_collect()) + monkeypatch.setattr(TraceQueueManager, "send_to_celery", lambda self, t: fake_send(t)) + manager.run() + assert called["tasks"] == [task] + + +def test_trace_queue_manager_send_to_celery(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + storage_save = MagicMock() + process_delay = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save) + monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay) + monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123"))) + + manager = TraceQueueManager(app_id="app-id", user_id="user") + + class DummyTraceInfo: + def model_dump(self): + return {"trace": "info"} + + class DummyTask: + def __init__(self): + self.app_id = "app-id" + + def execute(self): + return DummyTraceInfo() + + task = DummyTask() + manager.send_to_celery([task]) + storage_save.assert_called_once() + process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"}) diff --git a/api/tests/unit_tests/core/ops/test_utils.py b/api/tests/unit_tests/core/ops/test_utils.py index e1084001b7..8a89422782 100644 --- a/api/tests/unit_tests/core/ops/test_utils.py +++ b/api/tests/unit_tests/core/ops/test_utils.py @@ -1,9 +1,20 @@ import re from datetime import datetime +from unittest.mock import MagicMock, patch import pytest -from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import ( + filter_none_values, + generate_dotted_order, + get_message_data, + measure_time, + replace_text_with_content, + validate_integer_id, + validate_project_name, + validate_url, + validate_url_with_path, +) class TestValidateUrl: @@ -187,3 +198,92 @@ class TestGenerateDottedOrder: result = generate_dotted_order(run_id, start_time, None) assert "." 
not in result + + def test_dotted_order_with_string_start_time(self): + """Test dotted_order generation with string start_time.""" + start_time = "2025-12-23T04:19:55.111000" + run_id = "test-run-id" + result = generate_dotted_order(run_id, start_time) + + assert result == "20251223T041955111000Ztest-run-id" + + +class TestFilterNoneValues: + """Test cases for filter_none_values function""" + + def test_filter_none_values(self): + data = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)} + result = filter_none_values(data) + assert result == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"} + + def test_filter_none_values_empty(self): + assert filter_none_values({}) == {} + + +class TestGetMessageData: + """Test cases for get_message_data function""" + + @patch("core.ops.utils.db") + @patch("core.ops.utils.Message") + @patch("core.ops.utils.select") + def test_get_message_data(self, mock_select, mock_message, mock_db): + mock_scalar = mock_db.session.scalar + mock_msg_instance = MagicMock() + mock_scalar.return_value = mock_msg_instance + + result = get_message_data("message-id") + + assert result == mock_msg_instance + mock_select.assert_called_once() + mock_scalar.assert_called_once() + + +class TestMeasureTime: + """Test cases for measure_time function""" + + def test_measure_time(self): + with measure_time() as timing_info: + assert "start" in timing_info + assert isinstance(timing_info["start"], datetime) + assert timing_info["end"] is None + + assert timing_info["end"] is not None + assert isinstance(timing_info["end"], datetime) + assert timing_info["end"] >= timing_info["start"] + + +class TestReplaceTextWithContent: + """Test cases for replace_text_with_content function""" + + def test_replace_text_with_content_dict(self): + data = {"text": "hello", "other": "world"} + assert replace_text_with_content(data) == {"content": "hello", "other": "world"} + + def test_replace_text_with_content_nested(self): + data = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}} + expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}} + assert replace_text_with_content(data) == expected + + def test_replace_text_with_content_list(self): + data = [{"text": "v1"}, "v2"] + assert replace_text_with_content(data) == [{"content": "v1"}, "v2"] + + def test_replace_text_with_content_primitive(self): + assert replace_text_with_content(123) == 123 + assert replace_text_with_content("text") == "text" + + +class TestValidateIntegerId: + """Test cases for validate_integer_id function""" + + def test_valid_integer_id(self): + assert validate_integer_id("123") == "123" + assert validate_integer_id(" 456 ") == "456" + + def test_invalid_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("abc") + + def test_empty_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py new file mode 100644 index 0000000000..cdd97d5369 --- /dev/null +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -0,0 +1,1196 @@ +"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from 
weave.trace_server.trace_server_interface import TraceStatus + +from core.ops.entities.config_entity import WeaveConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.ops.weave_trace.weave_trace import WeaveDataTrace +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_weave_config(**overrides) -> WeaveConfig: + defaults = { + "api_key": "wv-api-key", + "project": "my-project", + "entity": "my-entity", + "host": None, + } + defaults.update(overrides) + return WeaveConfig(**defaults) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant-1", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + 
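"""Build a ToolTraceInfo with test defaults; keyword overrides replace individual fields.""" + 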
defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution object.""" + defaults = { + "id": "node-1", + "title": "Node Title", + "node_type": NodeType.CODE, + "status": "succeeded", + "inputs": {"key": "value"}, + "outputs": {"result": "ok"}, + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "metadata": {}, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_wandb(): + with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + mock.login.return_value = True + yield mock + + +@pytest.fixture +def mock_weave(): + with patch("core.ops.weave_trace.weave_trace.weave") as mock: + client = MagicMock() + client.entity = "my-entity" + client.project = "my-project" + mock.init.return_value = client + yield mock, client + + +@pytest.fixture +def trace_instance(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with mocked wandb/weave.""" + _, weave_client = mock_weave + config = _make_weave_config() + instance = WeaveDataTrace(config) + return instance + + +@pytest.fixture +def trace_instance_with_host(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with host configured.""" + _, weave_client = mock_weave + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + return instance + + +# ── TestInit ───────────────────────────────────────────────────────────────── + + +class TestInit: + def test_init_without_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login without host.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(host=None) + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with(key="wv-api-key", verify=True, relogin=True) + mock_w.init.assert_called_once_with(project_name="my-entity/my-project") + assert instance.weave_api_key == "wv-api-key" + assert instance.project_name == "my-project" + assert instance.entity == "my-entity" + assert instance.calls == {} + + def test_init_with_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login with host.""" + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with( + key="wv-api-key", verify=True, relogin=True, host="https://my.wandb.host" + ) + assert instance.host == "https://my.wandb.host" + + def test_init_without_entity(self, mock_wandb, mock_weave): + """Test __init__ initializes weave without entity prefix when entity is None.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + + 
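# Without an entity, the weave project path should be just the project name + 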
mock_w.init.assert_called_once_with(project_name="my-project") + + def test_init_login_failure_raises(self, mock_wandb, mock_weave): + """Test __init__ raises ValueError when wandb.login returns False.""" + mock_wandb.login.return_value = False + config = _make_weave_config() + + with pytest.raises(ValueError, match="Weave login failed"): + WeaveDataTrace(config) + + def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL is read from environment.""" + monkeypatch.setenv("FILES_URL", "http://files.example.com") + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://files.example.com" + + def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL defaults to http://127.0.0.1:5001.""" + monkeypatch.delenv("FILES_URL", raising=False) + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://127.0.0.1:5001" + + def test_project_id_set_correctly(self, trace_instance): + """Test that project_id is set from weave_client entity/project.""" + assert trace_instance.project_id == "my-entity/my-project" + + +# ── TestGetProjectUrl ───────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_get_project_url_with_entity(self, trace_instance): + """Returns wandb URL with entity/project.""" + url = trace_instance.get_project_url() + assert url == "https://wandb.ai/my-entity/my-project" + + def test_get_project_url_without_entity(self, mock_wandb, mock_weave): + """Returns wandb URL with project only when entity is None.""" + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + url = instance.get_project_url() + assert url == "https://wandb.ai/my-project" + + def test_get_project_url_missing_fields_falls_back(self, trace_instance): + """Falls back to the base wandb URL when entity and project_name are unset.""" + trace_instance.entity = None + trace_instance.project_name = None + url = trace_instance.get_project_url() + assert "https://wandb.ai/" in url + + +# ── TestTraceDispatcher ───────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow_trace(self, trace_instance): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message_trace(self, trace_instance): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_moderation_trace(self, trace_instance): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + msg_data = MagicMock() + msg_data.created_at = _dt() + trace_instance.trace(_make_moderation_trace_info(message_data=msg_data)) + mock_mod.assert_called_once() + + def test_dispatches_suggested_question_trace(self, trace_instance): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_dataset_retrieval_trace(self, 
trace_instance): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_tool_trace(self, trace_instance): + with patch.object(trace_instance, "tool_trace") as mock_tool: + trace_instance.trace(_make_tool_trace_info()) + mock_tool.assert_called_once() + + def test_dispatches_generate_name_trace(self, trace_instance): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + +# ── TestNormalizeTime ───────────────────────────────────────────────────────── + + +class TestNormalizeTime: + def test_none_returns_utc_now(self, trace_instance): + now_before = datetime.now(UTC) + result = trace_instance._normalize_time(None) + now_after = datetime.now(UTC) + assert result.tzinfo is not None + assert now_before <= result <= now_after + + def test_naive_datetime_gets_utc(self, trace_instance): + naive = datetime(2024, 6, 15, 12, 0, 0) + result = trace_instance._normalize_time(naive) + assert result.tzinfo == UTC + assert result.year == 2024 + assert result.month == 6 + + def test_aware_datetime_unchanged(self, trace_instance): + aware = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + result = trace_instance._normalize_time(aware) + assert result == aware + assert result.tzinfo == UTC + + +# ── TestStartCall ───────────────────────────────────────────────────────────── + + +class TestStartCall: + def test_start_call_basic(self, trace_instance): + """Test basic start_call stores call metadata.""" + run = WeaveTraceModel( + id="run-1", + op="test-op", + inputs={"key": "val"}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run) + + assert "run-1" in trace_instance.calls + assert trace_instance.calls["run-1"]["trace_id"] == "t-1" + assert trace_instance.calls["run-1"]["parent_id"] is None + trace_instance.weave_client.server.call_start.assert_called_once() + + def test_start_call_with_parent(self, trace_instance): + """Test start_call records parent_run_id.""" + run = WeaveTraceModel( + id="child-1", + op="child-op", + inputs={}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run, parent_run_id="parent-1") + + assert trace_instance.calls["child-1"]["parent_id"] == "parent-1" + + def test_start_call_none_inputs_becomes_empty_dict(self, trace_instance): + """Test that None inputs is normalized to {}.""" + run = WeaveTraceModel( + id="run-2", + op="op", + inputs=None, + attributes={"trace_id": "t-2", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert req.start.inputs == {} + + def test_start_call_non_dict_inputs_becomes_str_dict(self, trace_instance): + """Test that non-dict inputs is wrapped as string.""" + run = WeaveTraceModel( + id="run-3", + op="op", + inputs="some string input", + attributes={"trace_id": "t-3", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + # String inputs gets converted by validator to a dict + assert isinstance(req.start.inputs, dict) + + def test_start_call_none_attributes_becomes_empty_dict(self, trace_instance): + """Test that None attributes is handled properly.""" + run = WeaveTraceModel( + id="run-4", + op="op", + inputs={}, + attributes=None, + ) + 
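# Should not raise despite attributes=None + 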
trace_instance.start_call(run) + # trace_id should fall back to run_data.id + assert trace_instance.calls["run-4"]["trace_id"] == "run-4" + + def test_start_call_non_dict_attributes_becomes_dict(self, trace_instance): + """Test that non-dict attributes is wrapped.""" + run = WeaveTraceModel( + id="run-5", + op="op", + inputs={}, + attributes=None, + ) + # Manually override after construction + run.attributes = "some-attr-string" + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert isinstance(req.start.attributes, dict) + assert req.start.attributes == {"attributes": "some-attr-string"} + + def test_start_call_trace_id_falls_back_to_run_id(self, trace_instance): + """When trace_id not in attributes, falls back to run_data.id.""" + run = WeaveTraceModel( + id="run-6", + op="op", + inputs={}, + attributes={"start_time": _dt()}, + ) + trace_instance.start_call(run) + assert trace_instance.calls["run-6"]["trace_id"] == "run-6" + + +# ── TestFinishCall ────────────────────────────────────────────────────────── + + +class TestFinishCall: + def _setup_call(self, trace_instance, run_id="run-1", trace_id="t-1"): + """Helper: register a call so finish_call can find it.""" + trace_instance.calls[run_id] = {"trace_id": trace_id, "parent_id": None} + + def test_finish_call_success(self, trace_instance): + """Test finish_call sends call_end with SUCCESS status.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={"result": "ok"}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 1 + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 0 + assert req.end.exception is None + + def test_finish_call_with_error(self, trace_instance): + """Test finish_call sends call_end with ERROR status when exception is set.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception="Something broke", + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 1 + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 0 + assert req.end.exception == "Something broke" + + def test_finish_call_missing_id_raises(self, trace_instance): + """Test finish_call raises ValueError when call id not found.""" + run = WeaveTraceModel( + id="nonexistent", + op="op", + inputs={}, + ) + with pytest.raises(ValueError, match="Call with id nonexistent not found"): + trace_instance.finish_call(run) + + def test_finish_call_elapsed_negative_clamped_to_zero(self, trace_instance): + """Test that negative elapsed time is clamped to 0.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes={ + "start_time": _dt() + timedelta(seconds=5), + "end_time": _dt(), # end before start + }, + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["weave"]["latency_ms"] 
== 0 + + def test_finish_call_none_attributes(self, trace_instance): + """Test finish_call handles None attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + def test_finish_call_non_dict_attributes(self, trace_instance): + """Test finish_call handles non-dict attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + run.attributes = "some string attr" + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + +# ── TestWorkflowTrace ───────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def _setup_repo(self, monkeypatch, nodes=None): + """Helper to patch session/repo dependencies.""" + if nodes is None: + nodes = [] + + repo = MagicMock() + repo.get_by_workflow_run.return_value = nodes + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + + monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + return repo + + def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): + """Workflow trace with no nodes and no message_id.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Only workflow run: start_call and finish_call each called once + assert trace_instance.start_call.call_count == 1 + assert trace_instance.finish_call.call_count == 1 + + def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch): + """Workflow trace with message_id creates both message and workflow runs.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id="msg-1") + trace_instance.workflow_trace(trace_info) + + # message run + workflow run = 2 start_call / finish_call + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch): + """Workflow trace iterates node executions and creates node runs.""" + node = _make_node( + id="node-1", + node_type=NodeType.CODE, + inputs={"k": "v"}, + outputs={"r": "ok"}, + elapsed_time=0.5, + created_at=_dt(), + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 5}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # workflow run + node run = 2 calls + assert trace_instance.start_call.call_count 
== 2 + + def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch): + """LLM node uses process_data prompts as inputs.""" + node = _make_node( + node_type=NodeType.LLM, + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_mode": "chat", + "model_provider": "openai", + "model_name": "gpt-4", + }, + inputs={"key": "val"}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Check node start_call was called with prompts input + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + # WeaveTraceModel validator wraps list prompts into {"messages": [...]} + # The key "messages" should be present (validator transforms the list) + assert "messages" in node_run.inputs + + def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch): + """Non-LLM node uses node_execution.inputs directly.""" + node = _make_node( + node_type=NodeType.TOOL, + inputs={"tool_input": "val"}, + process_data=None, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # node run inputs should be from node.inputs; validator adds usage_metadata + file_list + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + assert node_run.inputs.get("tool_input") == "val" + + def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): + """Raises ValueError when app_id is missing from metadata.""" + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + + trace_info = _make_workflow_trace_info( + message_id=None, + metadata={"user_id": "u1"}, # no app_id + ) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch): + """start_time defaults to datetime.now() when None.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None, start_time=None) + trace_instance.workflow_trace(trace_info) + + assert trace_instance.start_call.call_count == 1 + + def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch): + """Node with created_at=None uses datetime.now().""" + node = _make_node(created_at=None, elapsed_time=0.5) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + assert 
trace_instance.start_call.call_count == 2 + + def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch): + """Chat mode LLM node adds ls_provider and ls_model_name to attributes.""" + node = _make_node( + node_type=NodeType.LLM, + process_data={"model_mode": "chat", "model_provider": "openai", "model_name": "gpt-4", "prompts": []}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + start_calls = [] + + def capture_start(run, parent_run_id=None): + start_calls.append((run, parent_run_id)) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Last start call is the node run + node_run, _ = start_calls[-1] + assert node_run.attributes.get("ls_provider") == "openai" + assert node_run.attributes.get("ls_model_name") == "gpt-4" + + def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch): + """Nodes are sorted by created_at before processing.""" + node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2)) + node2 = _make_node(id="node-a", created_at=_dt()) + self._setup_repo(monkeypatch, nodes=[node1, node2]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + processed_ids = [] + + def capture_start(run, parent_run_id=None): + processed_ids.append(run.id) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # First call = workflow run, then node-a, then node-b + assert processed_ids[1] == "node-a" + assert processed_ids[2] == "node-b" + + +# ── TestMessageTrace ────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """message_trace returns early when message_data is None.""" + trace_info = _make_message_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_message_trace(self, trace_instance, monkeypatch): + """message_trace creates message run and llm child run.""" + monkeypatch.setattr( + "core.ops.weave_trace.weave_trace.db.session.query", + lambda model: MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + ) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info() + trace_instance.message_trace(trace_info) + + # message run + llm child run + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_message_trace_with_file_data(self, trace_instance, monkeypatch): + """message_trace appends file URL to file_list.""" + file_data = MagicMock() + file_data.url = "path/to/file.png" + trace_instance.file_base_url = "http://files.test" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing.txt"], + ) + 
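# Assumed: message_trace joins the relative file path onto file_base_url + 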
trace_instance.message_trace(trace_info) + + # The first start_call arg (the message run) should include the resolved file URL in its file_list + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert "http://files.test/path/to/file.png" in message_run.file_list + + def test_message_trace_with_end_user(self, trace_instance, monkeypatch): + """message_trace looks up end user and sets end_user_id attribute.""" + end_user = MagicMock() + end_user.session_id = "session-xyz" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = "eu-1" + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.attributes.get("end_user_id") == "session-xyz" + + def test_message_trace_no_end_user(self, trace_instance, monkeypatch): + """message_trace handles when from_end_user_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): + """trace_id falls back to message_id when trace_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(trace_id=None) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.id == "msg-1" + + def test_message_trace_file_list_none(self, trace_instance, monkeypatch): + """message_trace handles file_list=None gracefully.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + +# ── TestModerationTrace ─────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """moderation_trace returns early when message_data is None.""" + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_moderation_trace(self, trace_instance): + """moderation_trace creates a run with correct 
outputs.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + action="block", + flagged=True, + preset_response="blocked", + ) + trace_instance.moderation_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.outputs["action"] == "block" + assert run.outputs["flagged"] is True + + def test_moderation_trace_with_no_times_uses_message_data_times(self, trace_instance): + """When start/end times are None, uses message_data created_at/updated_at.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + timedelta(seconds=1) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=None, + end_time=None, + ) + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_called_once() + + def test_moderation_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + trace_id=None, + ) + trace_instance.moderation_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestSuggestedQuestionTrace ──────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """suggested_question_trace returns early when message_data is None.""" + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance): + """suggested_question_trace creates a run parented to trace_id.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id="t-1") + trace_instance.suggested_question_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_suggested_question_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id=None) + trace_instance.suggested_question_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestDatasetRetrievalTrace ───────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """dataset_retrieval_trace returns early when message_data is None.""" + trace_info = 
_make_dataset_retrieval_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance): + """dataset_retrieval_trace creates a run with documents as outputs.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info( + documents=[{"id": "d1"}, {"id": "d2"}], + trace_id="t-1", + ) + trace_instance.dataset_retrieval_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + # WeaveTraceModel validator injects usage_metadata/file_list into dict outputs + assert run.outputs.get("documents") == [{"id": "d1"}, {"id": "d2"}] + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_dataset_retrieval_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info(trace_id=None) + trace_instance.dataset_retrieval_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestToolTrace ───────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance): + """tool_trace creates a run with correct op as tool_name.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id="t-1") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.op == "my_tool" + # WeaveTraceModel validator injects usage_metadata/file_list into dict inputs + assert run.inputs.get("x") == 1 + + def test_tool_trace_with_file_url(self, trace_instance): + """tool_trace adds file_url to file_list when provided.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url="http://files/file.pdf") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert "http://files/file.pdf" in run.file_list + + def test_tool_trace_without_file_url(self, trace_instance): + """tool_trace uses empty file_list when file_url is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url=None) + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.file_list == [] + + def test_tool_trace_trace_id_from_message_id(self, trace_instance): + """trace_id uses message_id fallback.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None) + trace_instance.tool_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + def test_tool_trace_message_id_none_uses_conversation_id(self, trace_instance): + """When message_id is None, tries conversation_id attribute.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None, message_id=None) + trace_instance.tool_trace(trace_info) + + # No crash; parent_run_id is None since no fallback + _, 
kwargs = trace_instance.start_call.call_args + # parent_run_id should be None when no message_id and no trace_id + assert kwargs.get("parent_run_id") is None + + +# ── TestGenerateNameTrace ───────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance): + """generate_name_trace creates a run with correct op.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.op == str(TraceTaskName.GENERATE_NAME_TRACE) + + def test_generate_name_trace_no_parent(self, trace_instance): + """generate_name_trace has no parent run (no parent_run_id).""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + # No parent_run_id passed to generate_name start_call + assert kwargs == {} or kwargs.get("parent_run_id") is None + + +# ── TestApiCheck ────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_api_check_success_without_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login without host.""" + trace_instance.host = None + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with(key=trace_instance.weave_api_key, verify=True, relogin=True) + + def test_api_check_success_with_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login with host.""" + trace_instance.host = "https://my.wandb.host" + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with( + key=trace_instance.weave_api_key, verify=True, relogin=True, host="https://my.wandb.host" + ) + + def test_api_check_login_failure_raises(self, trace_instance, mock_wandb): + """api_check raises ValueError when login returns False.""" + trace_instance.host = None + mock_wandb.login.return_value = False + + with pytest.raises(ValueError, match="Weave API check failed"): + trace_instance.api_check() + + def test_api_check_exception_raises_value_error(self, trace_instance, mock_wandb): + """api_check raises ValueError when wandb.login raises exception.""" + trace_instance.host = None + mock_wandb.login.side_effect = Exception("network error") + + with pytest.raises(ValueError, match="Weave API check failed: network error"): + trace_instance.api_check() diff --git a/api/tests/unit_tests/core/plugin/impl/__init__.py b/api/tests/unit_tests/core/plugin/impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/plugin/impl/test_agent_client.py b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py new file mode 100644 index 0000000000..1537ffacf5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace + +from core.plugin.entities.request import PluginInvokeContext +from core.plugin.impl.agent import PluginAgentClient + + +def _agent_provider(name: str = "agent") -> SimpleNamespace: + return SimpleNamespace( + 
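# Minimal provider stub with only the fields these tests read + 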
plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + strategies=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginAgentClient: + def test_fetch_agent_strategy_providers(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("remote") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "strategies": [{"identity": {"provider": "old"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["strategies"][0]["identity"]["provider"] == "remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.strategies[0].identity.provider == "org/plugin/remote" + + def test_fetch_agent_strategy_provider(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("provider") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + assert transformer({"data": None}) == {"data": None} + payload = {"data": {"declaration": {"strategies": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["strategies"][0]["identity"]["provider"] == "provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.strategies[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks_and_passes_context(self, mocker): + client = PluginAgentClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["raw"]) + ) + merge_mock = mocker.patch("core.plugin.impl.agent.merge_blob_chunks", return_value=["merged"]) + context = PluginInvokeContext() + + result = client.invoke( + tenant_id="tenant-1", + user_id="user-1", + agent_provider="org/plugin/provider", + agent_strategy="router", + agent_params={"k": "v"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + context=context, + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + payload = stream_mock.call_args.kwargs["data"] + assert payload["data"]["agent_strategy_provider"] == "provider" + assert payload["context"] == context.model_dump() + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" diff --git a/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py new file mode 100644 index 0000000000..5f564062d5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +import pytest + +from core.plugin.impl.asset import PluginAssetManager + + +class TestPluginAssetManager: + def test_fetch_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"asset-bytes") + request_mock = 
mocker.patch.object(manager, "_request", return_value=response) + + result = manager.fetch_asset("tenant-1", "asset-1") + + assert result == b"asset-bytes" + request_mock.assert_called_once_with(method="GET", path="plugin/tenant-1/asset/asset-1") + + def test_fetch_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset asset-1"): + manager.fetch_asset("tenant-1", "asset-1") + + def test_extract_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"file-content") + request_mock = mocker.patch.object(manager, "_request", return_value=response) + + result = manager.extract_asset("tenant-1", "org/plugin:1", "README.md") + + assert result == b"file-content" + request_mock.assert_called_once_with( + method="GET", + path="plugin/tenant-1/extract-asset/", + params={"plugin_unique_identifier": "org/plugin:1", "file_path": "README.md"}, + ) + + def test_extract_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset org/plugin:1, 404"): + manager.extract_asset("tenant-1", "org/plugin:1", "README.md") diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py new file mode 100644 index 0000000000..c216906d68 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py @@ -0,0 +1,137 @@ +import json + +import pytest + +from core.plugin.endpoint.exc import EndpointSetupFailedError +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.base import BasePluginClient +from core.trigger.errors import ( + EventIgnoreError, + TriggerInvokeError, + TriggerPluginInvokeError, + TriggerProviderCredentialValidationError, +) + + +class _ResponseStub: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _StreamContext: + def __init__(self, lines): + self._lines = lines + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def iter_lines(self): + return self._lines + + +class TestBasePluginClientImpl: + def test_inject_trace_headers(self, mocker): + client = BasePluginClient() + mocker.patch("core.plugin.impl.base.dify_config.ENABLE_OTEL", True) + trace_header = "00-abc-xyz-01" + mocker.patch("core.helper.trace_id_helper.generate_traceparent_header", return_value=trace_header) + + headers = {} + client._inject_trace_headers(headers) + + assert headers["traceparent"] == trace_header + + headers_with_existing = {"TraceParent": "exists"} + client._inject_trace_headers(headers_with_existing) + assert headers_with_existing["TraceParent"] == "exists" + + def test_stream_request_handles_data_lines_and_dict_payload(self, mocker): + client = BasePluginClient() + stream_mock = mocker.patch( + "core.plugin.impl.base.httpx.stream", + return_value=_StreamContext([b"", b"data: hello", "world"]), + ) + + result = list(client._stream_request("POST", "plugin/tenant/stream", data={"k": "v"})) + + assert result == ["hello", "world"] + assert stream_mock.call_args.kwargs["data"] == {"k": "v"} + + def 
test_request_with_plugin_daemon_response_handles_request_exception(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", side_effect=RuntimeError("boom")) + + with pytest.raises(ValueError, match="Failed to request plugin daemon"): + client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool) + + def test_request_with_plugin_daemon_response_applies_transformer(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", return_value=_ResponseStub({"code": 0, "message": "", "data": True})) + + transformed = {} + + def transformer(payload): + transformed.update(payload) + return payload + + result = client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool, transformer=transformer) + + assert result is True + assert transformed == {"code": 0, "message": "", "data": True} + + def test_request_with_plugin_daemon_response_stream_malformed_json_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"error":"bad-line"}'])) + + with pytest.raises(ValueError, match="bad-line"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_inner_error(self, mocker): + client = BasePluginClient() + mocker.patch.object( + client, "_stream_request", return_value=iter(['{"code":-500,"message":"not-json","data":null}']) + ) + + with pytest.raises(PluginDaemonInnerError) as exc_info: + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + assert exc_info.value.message == "not-json" + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":-1,"message":"err","data":null}'])) + + with pytest.raises(ValueError, match="plugin daemon: err, code: -1"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_empty_data_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":0,"message":"","data":null}'])) + + with pytest.raises(ValueError, match="got empty data"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + @pytest.mark.parametrize( + ("error_type", "expected"), + [ + (EndpointSetupFailedError.__name__, EndpointSetupFailedError), + (TriggerProviderCredentialValidationError.__name__, TriggerProviderCredentialValidationError), + (TriggerPluginInvokeError.__name__, TriggerPluginInvokeError), + (TriggerInvokeError.__name__, TriggerInvokeError), + (EventIgnoreError.__name__, EventIgnoreError), + ], + ) + def test_handle_plugin_daemon_error_trigger_branches(self, error_type, expected): + client = BasePluginClient() + message = json.dumps({"error_type": error_type, "message": "m"}) + + with pytest.raises(expected): + client._handle_plugin_daemon_error("PluginInvokeError", message) diff --git a/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py new file mode 100644 index 0000000000..4c5987d759 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace + +from core.datasource.entities.datasource_entities import ( + GetOnlineDocumentPageContentRequest, + 
OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +def _datasource_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + datasources=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginDatasourceManager: + def test_fetch_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 2 + assert result[0].plugin_id == "langgenius/file" + assert result[1].declaration.identity.name == "org/plugin/remote" + assert result[1].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_installed_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformer(payload) + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_installed_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_datasource_provider_local_and_remote(self, mocker): + manager = PluginDatasourceManager() + + local = manager.fetch_datasource_provider("tenant-1", "langgenius/file/file") + assert local.plugin_id == "langgenius/file" + + remote = _datasource_provider("provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": { + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}] + } + } + } + transformed = transformer(payload) + assert 
transformed["data"]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return remote + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.datasources[0].identity.provider == "org/plugin/provider" + + def test_get_website_crawl_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["crawl"]) + + assert list( + manager.get_website_crawl( + "tenant-1", + "user-1", + "org/plugin/provider", + "crawl", + {"k": "v"}, + {"url": "https://example.com"}, + "website", + ) + ) == ["crawl"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_pages_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["pages"]) + + assert list( + manager.get_online_document_pages( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + {"workspace": "w1"}, + "online_document", + ) + ) == ["pages"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_page_content_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["content"]) + + assert list( + manager.get_online_document_page_content( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + GetOnlineDocumentPageContentRequest(workspace_id="w", page_id="p", type="doc"), + "online_document", + ) + ) == ["content"] + + assert stream_mock.call_count == 1 + + def test_online_drive_browse_files_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["browse"]) + + assert list( + manager.online_drive_browse_files( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveBrowseFilesRequest(prefix="/"), + "online_drive", + ) + ) == ["browse"] + + assert stream_mock.call_count == 1 + + def test_online_drive_download_file_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["download"]) + + assert list( + manager.online_drive_download_file( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveDownloadFileRequest(id="file-1"), + "online_drive", + ) + ) == ["download"] + + assert stream_mock.call_count == 1 + + def test_validate_provider_credentials_returns_true_when_stream_yields_result(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + + assert manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is True + + def test_validate_provider_credentials_returns_false_when_stream_empty(self, mocker): + manager = PluginDatasourceManager() + stream_mock = 
mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([]) + + assert ( + manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is False + ) + + def test_local_file_provider_template(self): + manager = PluginDatasourceManager() + + payload = manager._get_local_file_datasource_provider() + + assert payload["plugin_id"] == "langgenius/file" + assert payload["provider"] == "file" + assert payload["declaration"]["provider_type"] == "local_file" diff --git a/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py new file mode 100644 index 0000000000..c80785aee0 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py @@ -0,0 +1,21 @@ +from types import SimpleNamespace + +from core.plugin.impl.debugging import PluginDebuggingClient + + +class TestPluginDebuggingClient: + def test_get_debugging_key(self, mocker): + client = PluginDebuggingClient() + request_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response", + return_value=SimpleNamespace(key="debug-key"), + ) + + result = client.get_debugging_key("tenant-1") + + assert result == "debug-key" + request_mock.assert_called_once() + args = request_mock.call_args.args + assert args[0] == "POST" + assert args[1] == "plugin/tenant-1/debugging/key" diff --git a/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py new file mode 100644 index 0000000000..4cf657a050 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py @@ -0,0 +1,71 @@ +import pytest + +from core.plugin.impl.endpoint import PluginEndpointClient +from core.plugin.impl.exc import PluginDaemonInternalServerError + + +class TestPluginEndpointClientImpl: + def test_create_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.create_endpoint("tenant-1", "user-1", "org/plugin:1", "endpoint-a", {"k": "v"}) + + assert result is True + assert request_mock.call_count == 1 + args = request_mock.call_args.args + kwargs = request_mock.call_args.kwargs + assert args[:3] == ("POST", "plugin/tenant-1/endpoint/setup", bool) + assert kwargs["data"]["plugin_unique_identifier"] == "org/plugin:1" + + def test_list_endpoints(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints("tenant-1", "user-1", 2, 20) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list" + assert request_mock.call_args.kwargs["params"] == {"page": 2, "page_size": 20} + + def test_list_endpoints_for_single_plugin(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints_for_single_plugin("tenant-1", "user-1", "org/plugin", 1, 10) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list/plugin" + assert request_mock.call_args.kwargs["params"] == {"plugin_id": "org/plugin", "page": 1, "page_size": 10} + + def test_update_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = 
mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.update_endpoint("tenant-1", "user-1", "endpoint-1", "renamed", {"x": 1}) + + assert result is True + assert request_mock.call_args.args[:3] == ("POST", "plugin/tenant-1/endpoint/update", bool) + + def test_enable_and_disable_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + assert client.enable_endpoint("tenant-1", "user-1", "endpoint-1") is True + assert client.disable_endpoint("tenant-1", "user-1", "endpoint-1") is True + + calls = request_mock.call_args_list + assert calls[0].args[1] == "plugin/tenant-1/endpoint/enable" + assert calls[1].args[1] == "plugin/tenant-1/endpoint/disable" + + def test_delete_endpoint_idempotent_and_re_raise(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response") + + request_mock.side_effect = PluginDaemonInternalServerError("record not found") + assert client.delete_endpoint("tenant-1", "user-1", "endpoint-1") is True + + request_mock.side_effect = PluginDaemonInternalServerError("permission denied") + with pytest.raises(PluginDaemonInternalServerError) as exc_info: + client.delete_endpoint("tenant-1", "user-1", "endpoint-1") + assert "permission denied" in exc_info.value.description diff --git a/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py new file mode 100644 index 0000000000..8c6f1c6b7f --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py @@ -0,0 +1,41 @@ +import json + +from core.plugin.impl import exc as exc_module +from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError + + +class TestPluginImplExceptions: + def test_plugin_daemon_error_str_contains_request_id(self, mocker): + mocker.patch("core.plugin.impl.exc.get_request_id", return_value="req-123") + error = PluginDaemonError("bad") + + assert str(error) == "req_id: req-123 PluginDaemonError: bad" + + def test_plugin_invoke_error_with_json_payload(self): + err = PluginInvokeError(json.dumps({"error_type": "RateLimit", "message": "too many"})) + + assert err.get_error_type() == "RateLimit" + assert err.get_error_message() == "too many" + friendly = err.to_user_friendly_error("test-plugin") + assert "test-plugin" in friendly + assert "RateLimit" in friendly + assert "too many" in friendly + + def test_plugin_invoke_error_invalid_json_and_fallback(self, mocker): + err = PluginInvokeError("plain text") + + assert err._get_error_object() == {} + assert err.get_error_type() == "unknown" + assert err.get_error_message() == "unknown" + + mocker.patch.object(PluginInvokeError, "_get_error_object", side_effect=RuntimeError("boom")) + err2 = PluginInvokeError("plain text") + assert err2.get_error_message() == "plain text" + + def test_plugin_invoke_error_get_error_object_handles_adapter_exception(self, mocker): + adapter = mocker.patch.object(exc_module, "TypeAdapter") + adapter.return_value.validate_json.side_effect = RuntimeError("invalid") + + err = PluginInvokeError("not-json") + + assert err._get_error_object() == {} diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_client.py b/api/tests/unit_tests/core/plugin/impl/test_model_client.py new file mode 100644 index 0000000000..bcbebbb38b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_client.py @@ -0,0 +1,490 @@ +from __future__ 
import annotations + +import io +from types import SimpleNamespace + +import pytest + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.model import PluginModelClient + + +class TestPluginModelClient: + def test_fetch_model_providers(self, mocker): + client = PluginModelClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["provider-a"]) + + result = client.fetch_model_providers("tenant-1") + + assert result == ["provider-a"] + assert request_mock.call_args.args[:2] == ( + "GET", + "plugin/tenant-1/management/models", + ) + assert request_mock.call_args.kwargs["params"] == {"page": 1, "page_size": 256} + + def test_get_model_schema(self, mocker): + client = PluginModelClient() + schema = SimpleNamespace(name="schema") + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(model_schema=schema)]), + ) + + result = client.get_model_schema( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={"api_key": "key"}, + ) + + assert result is schema + assert stream_mock.call_args.args[:2] == ("POST", "plugin/tenant-1/dispatch/model/schema") + + def test_get_model_schema_empty_stream_returns_none(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + result = client.get_model_schema("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + + assert result is None + + def test_validate_provider_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"api_key": "new"})]), + ) + credentials = {"api_key": "old"} + + result = client.validate_provider_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + credentials=credentials, + ) + + assert result is True + assert credentials["api_key"] == "new" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_provider_credentials", + ) + + def test_validate_provider_credentials_without_dict_update(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=False, credentials="not-a-dict")]), + ) + credentials = {"api_key": "same"} + + result = client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", credentials) + + assert result is False + assert credentials == {"api_key": "same"} + + def test_validate_provider_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", {}) is False + + def test_validate_model_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"token": "rotated"})]), + ) + credentials = {"token": "old"} + + result = client.validate_model_credentials( + tenant_id="tenant-1", + 
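# on success the client is expected to write the rotated credentials back into this dict (asserted below) +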
user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials=credentials, + ) + + assert result is True + assert credentials["token"] == "rotated" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_model_credentials", + ) + + def test_validate_model_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.validate_model_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + is False + ) + + def test_invoke_llm(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"]) + ) + + result = list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={"api_key": "key"}, + prompt_messages=[], + model_parameters={"temperature": 0.1}, + tools=[], + stop=["STOP"], + stream=False, + ) + ) + + assert result == ["chunk-1"] + call_kwargs = stream_mock.call_args.kwargs + assert call_kwargs["path"] == "plugin/tenant-1/dispatch/llm/invoke" + assert call_kwargs["data"]["data"]["stream"] is False + assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1} + + def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-500, message="invoke failed") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="invoke failed-500"): + list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={}, + prompt_messages=[], + ) + ) + + def test_get_llm_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=42)]), + ) + + result = client.get_llm_num_tokens( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={}, + prompt_messages=[], + tools=[], + ) + + assert result == 42 + + def test_get_llm_num_tokens_empty_returns_zero(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_llm_num_tokens("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}, []) + == 0 + ) + + def test_invoke_text_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.1, 0.2]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_text_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + texts=["hello"], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_text_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", 
return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke text embedding"): + client.invoke_text_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["hello"], "x" + ) + + def test_invoke_multimodal_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.3, 0.4]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_multimodal_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + documents=[{"type": "image", "value": "abc"}], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_multimodal_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke file embedding"): + client.invoke_multimodal_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, [{"type": "image"}], "x" + ) + + def test_get_text_embedding_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=[1, 2, 3])]), + ) + + assert client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) == [ + 1, + 2, + 3, + ] + + def test_get_text_embedding_num_tokens_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) + == [] + ) + + def test_invoke_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.9]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query="q", + docs=["doc-1"], + score_threshold=0.2, + top_n=5, + ) + + assert result is rerank_result + + def test_invoke_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke rerank"): + client.invoke_rerank("tenant-1", "user-1", "org/plugin:1", "provider-a", "rerank-a", {}, "q", ["doc-1"]) + + def test_invoke_multimodal_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.8]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_multimodal_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query={"type": "text", "value": "q"}, + docs=[{"type": "image", "value": "doc"}], + score_threshold=0.1, + top_n=3, + ) + + assert result is rerank_result + + def test_invoke_multimodal_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, 
"_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke multimodal rerank"): + client.invoke_multimodal_rerank( + "tenant-1", + "user-1", + "org/plugin:1", + "provider-a", + "rerank-a", + {}, + {"type": "text"}, + [{"type": "image"}], + ) + + def test_invoke_tts(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="68656c6c6f"), SimpleNamespace(result="21")]), + ) + + result = list( + client.invoke_tts( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + content_text="hello", + voice="alloy", + ) + ) + + assert result == [b"hello", b"!"] + + def test_invoke_tts_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-400, message="tts error") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="tts error-400"): + list(client.invoke_tts("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}, "hello", "alloy")) + + def test_get_tts_model_voices(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter( + [ + SimpleNamespace( + voices=[ + SimpleNamespace(name="Alloy", value="alloy"), + SimpleNamespace(name="Echo", value="echo"), + ] + ) + ] + ), + ) + + result = client.get_tts_model_voices( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + language="en", + ) + + assert result == [{"name": "Alloy", "value": "alloy"}, {"name": "Echo", "value": "echo"}] + + def test_get_tts_model_voices_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.get_tts_model_voices("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}) == [] + + def test_invoke_speech_to_text(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="transcribed text")]), + ) + + result = client.invoke_speech_to_text( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="stt-a", + credentials={}, + file=io.BytesIO(b"abc"), + ) + + assert result == "transcribed text" + assert stream_mock.call_args.kwargs["data"]["data"]["file"] == "616263" + + def test_invoke_speech_to_text_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke speech to text"): + client.invoke_speech_to_text( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "stt-a", {}, io.BytesIO(b"abc") + ) + + def test_invoke_moderation(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True)]), + ) + + result = client.invoke_moderation( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + 
model="moderation-a", + credentials={}, + text="safe text", + ) + + assert result is True + assert stream_mock.call_args.kwargs["path"] == "plugin/tenant-1/dispatch/moderation/invoke" + + def test_invoke_moderation_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke moderation"): + client.invoke_moderation("tenant-1", "user-1", "org/plugin:1", "provider-a", "moderation-a", {}, "unsafe") diff --git a/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py new file mode 100644 index 0000000000..6fb4c99432 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py @@ -0,0 +1,147 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.impl.oauth import OAuthHandler + + +def _build_request(body: bytes = b"payload") -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/oauth/callback", + "QUERY_STRING": "code=123", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "HTTP_HOST": "localhost", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_X_TEST": "yes", + } + return Request(environ) + + +class TestOAuthHandler: + def test_get_authorization_url(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(authorization_url="https://auth.example.com")]), + ) + + response = handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + ) + + assert response.authorization_url == "https://auth.example.com" + assert stream_mock.call_count == 1 + + def test_get_authorization_url_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error getting authorization URL"): + handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + ) + + def test_get_credentials(self, mocker): + handler = OAuthHandler() + captured_data = {} + + def fake_stream(*args, **kwargs): + captured_data.update(kwargs["data"]) + return iter([SimpleNamespace(credentials={"token": "abc"}, metadata={}, expires_at=1)]) + + stream_mock = mocker.patch.object( + handler, "_request_with_plugin_daemon_response_stream", side_effect=fake_stream + ) + + response = handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + request=_build_request(), + ) + + assert response.credentials == {"token": "abc"} + assert "raw_http_request" in captured_data["data"] + assert stream_mock.call_count == 1 + + def test_get_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, 
match="Error getting credentials"): + handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + request=_build_request(), + ) + + def test_refresh_credentials(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(credentials={"token": "new"}, metadata={}, expires_at=1)]), + ) + + response = handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + credentials={"refresh_token": "r"}, + ) + + assert response.credentials == {"token": "new"} + assert stream_mock.call_count == 1 + + def test_refresh_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error refreshing credentials"): + handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + credentials={}, + ) + + def test_convert_request_to_raw_data(self): + handler = OAuthHandler() + request = _build_request(b"body-data") + + raw = handler._convert_request_to_raw_data(request) + + assert raw.startswith(b"POST /oauth/callback?code=123 HTTP/1.1\r\n") + assert b"X-Test: yes\r\n" in raw + assert raw.endswith(b"body-data") diff --git a/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py new file mode 100644 index 0000000000..80cf46f9bb --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py @@ -0,0 +1,121 @@ +from types import SimpleNamespace + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.tool import PluginToolManager + + +def _tool_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + tools=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginToolManager: + def test_fetch_tool_providers(self, mocker): + manager = PluginToolManager() + provider = _tool_provider("remote") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.tools[0].identity.provider == "org/plugin/remote" + + def test_fetch_tool_provider(self, mocker): + manager = PluginToolManager() + provider = 
_tool_provider("provider") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": {"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]} + } + } + transformed = transformer(payload) + assert transformed["data"]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return provider + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.tools[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object( + manager, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk"]) + ) + merge_mock = mocker.patch("core.plugin.impl.tool.merge_blob_chunks", return_value=["merged"]) + + result = manager.invoke( + tenant_id="tenant-1", + user_id="user-1", + tool_provider="org/plugin/provider", + tool_name="search", + credentials={"api_key": "k"}, + credential_type=CredentialType.API_KEY, + tool_parameters={"q": "python"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" + + def test_validate_credentials_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + def test_get_runtime_parameters_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(parameters=[{"name": "p"}])]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [{"name": "p"}] + + stream_mock.return_value = iter([]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [] diff --git a/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py new file mode 100644 index 0000000000..76da51c2c8 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py @@ -0,0 +1,226 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.trigger import 
PluginTriggerClient +from core.trigger.entities.entities import Subscription +from models.provider_ids import TriggerProviderID + + +def _request() -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/events", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(b"payload"), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": "7", + "HTTP_HOST": "localhost", + } + return Request(environ) + + +def _subscription() -> Subscription: + return Subscription(expires_at=123, endpoint="https://example.com/hook", parameters={"a": 1}, properties={"p": 1}) + + +def _trigger_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + events=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +def _subscription_call_kwargs(method_name: str) -> dict: + if method_name == "subscribe": + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + "endpoint": "https://example.com/hook", + "parameters": {"k": "v"}, + } + + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "subscription": _subscription(), + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + } + + +class TestPluginTriggerClient: + def test_fetch_trigger_providers(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("remote") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "plugin_id": "org/plugin", + "provider": "remote", + "declaration": {"events": [{"identity": {"provider": "old"}}]}, + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.events[0].identity.provider == "org/plugin/remote" + + def test_fetch_trigger_provider(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("provider") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = {"data": {"declaration": {"events": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_provider("tenant-1", TriggerProviderID("org/plugin/provider")) + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.events[0].identity.provider == "org/plugin/provider" + + def test_invoke_trigger_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(variables={"ok": True}, cancelled=False)]), + ) + + result = client.invoke_trigger_event( + tenant_id="tenant-1", + 
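# the werkzeug request and subscription fixtures built above are handed through here +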
user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + assert result.variables == {"ok": True} + assert stream_mock.call_count == 1 + + def test_invoke_trigger_event_no_response_raises(self, mocker): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="No response received from plugin daemon for invoke trigger"): + client.invoke_trigger_event( + tenant_id="tenant-1", + user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + def test_validate_provider_credentials(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + with pytest.raises( + ValueError, match="No response received from plugin daemon for validate provider credentials" + ): + client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) + + def test_dispatch_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(user_id="u", events=["e"])]), + ) + + result = client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result.user_id == "u" + assert stream_mock.call_count == 1 + + stream_mock.return_value = iter([]) + with pytest.raises(ValueError, match="No response received from plugin daemon for dispatch event"): + client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + @pytest.mark.parametrize("method_name", ["subscribe", "unsubscribe", "refresh"]) + def test_subscription_operations_success(self, mocker, method_name): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(subscription={"id": "sub"})]), + ) + + method = getattr(client, method_name) + result = method(**_subscription_call_kwargs(method_name)) + + assert result.subscription == {"id": "sub"} + assert stream_mock.call_count == 1 + + @pytest.mark.parametrize( + ("method_name", "expected"), + [ + ("subscribe", "No response received from plugin daemon for subscribe"), + ("unsubscribe", "No response received from plugin daemon for unsubscribe"), + ("refresh", "No response received from plugin daemon for refresh"), + ], + ) + def test_subscription_operations_no_response(self, mocker, method_name, expected): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + method = getattr(client, 
method_name) + + with pytest.raises(ValueError, match=expected): + method(**_subscription_call_kwargs(method_name)) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index a380149554..c2778f082b 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -1,72 +1,359 @@ +import json from types import SimpleNamespace from unittest.mock import MagicMock +import pytest +from pydantic import BaseModel + from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from models.model import AppMode -def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" +class _Chunk(BaseModel): + value: int - app = MagicMock() - app.mode = AppMode.ADVANCED_CHAT - app.workflow = workflow - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), +class TestBaseBackwardsInvocation: + def test_convert_to_event_stream_with_generator_and_error(self): + def _stream(): + yield _Chunk(value=1) + yield {"x": 2} + yield "ignored" + raise RuntimeError("boom") + + chunks = list(BaseBackwardsInvocation.convert_to_event_stream(_stream())) + + assert len(chunks) == 3 + first = json.loads(chunks[0].decode()) + second = json.loads(chunks[1].decode()) + error = json.loads(chunks[2].decode()) + assert first["data"]["value"] == 1 + assert second["data"]["x"] == 2 + assert error["error"] == "boom" + + def test_convert_to_event_stream_with_non_generator(self): + chunks = list(BaseBackwardsInvocation.convert_to_event_stream({"ok": True})) + payload = json.loads(chunks[0].decode()) + assert payload["data"] == {"ok": True} + assert payload["error"] == "" + + +class TestPluginAppBackwardsInvocation: + def test_fetch_app_info_workflow_path(self, mocker): + workflow = MagicMock() + workflow.features_dict = {"feature": "v"} + workflow.user_input_form.return_value = [{"name": "foo"}] + app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mapper = mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result == {"data": {"mapped": True}} + mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}]) + + def test_fetch_app_info_model_config_path(self, mocker): + model_config = MagicMock() + model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"} + app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result["data"] == {"mapped": True} + + @pytest.mark.parametrize( + ("mode", "route_method"), + [ + (AppMode.CHAT, "invoke_chat_app"), + (AppMode.ADVANCED_CHAT, "invoke_chat_app"), + (AppMode.AGENT_CHAT, "invoke_chat_app"), + (AppMode.WORKFLOW, 
"invoke_workflow_app"), + (AppMode.COMPLETION, "invoke_completion_app"), + ], ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, + def test_invoke_app_routes_by_mode(self, mocker, mode, route_method): + app = MagicMock(mode=mode) + user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user) + route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="hello", + stream=False, + inputs={"x": 1}, + files=[], + ) + + assert result == {"routed": True} + assert route.call_count == 1 + + def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker): + app = MagicMock(mode=AppMode.WORKFLOW) + end_user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + get_or_create = mocker.patch( + "core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user", + return_value=end_user, + ) + route = mocker.patch.object(PluginAppBackwardsInvocation, "invoke_workflow_app", return_value={"ok": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="", + tenant_id="tenant", + conversation_id="", + query=None, + stream=True, + inputs={}, + files=[], + ) + + assert result == {"ok": True} + get_or_create.assert_called_once_with(app) + assert route.call_args.args[1] is end_user + + def test_invoke_app_missing_query_for_chat_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT)) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="missing query"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_app_unexpected_mode_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode="other")) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="q", + stream=False, + inputs={}, + files=[], + ) + + @pytest.mark.parametrize( + ("mode", "generator_path"), + [ + (AppMode.AGENT_CHAT, "core.plugin.backwards_invocation.app.AgentChatAppGenerator.generate"), + (AppMode.CHAT, "core.plugin.backwards_invocation.app.ChatAppGenerator.generate"), + ], ) + def test_invoke_chat_app_agent_and_chat(self, mocker, mode, generator_path): + app = MagicMock(mode=mode, workflow=None) + spy = mocker.patch(generator_path, return_value={"result": "ok"}) - result = PluginAppBackwardsInvocation.invoke_chat_app( - app=app, - user=MagicMock(), - conversation_id="conv-1", - query="hello", - stream=False, - inputs={"k": "v"}, - files=[], - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - assert result == {"result": "ok"} - call_kwargs = 
generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + assert result == {"result": "ok"} + assert spy.call_count == 1 + def test_invoke_chat_app_advanced_chat_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" -def test_invoke_workflow_app_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" + app = MagicMock() + app.mode = AppMode.ADVANCED_CHAT + app.workflow = workflow - app = MagicMock() - app.mode = AppMode.WORKFLOW - app.workflow = workflow + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", + return_value={"result": "ok"}, + ) - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), - ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - result = PluginAppBackwardsInvocation.invoke_workflow_app( - app=app, - user=MagicMock(), - stream=False, - inputs={"k": "v"}, - files=[], - ) + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + def test_invoke_chat_app_advanced_chat_without_workflow_raises(self): + app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_chat_app_unexpected_mode_raises(self): + app = MagicMock(mode="invalid") + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_workflow_app_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.WORKFLOW + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, 
PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" + + def test_invoke_workflow_app_without_workflow_raises(self): + app = MagicMock(mode=AppMode.WORKFLOW, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_completion_app(self, mocker): + spy = mocker.patch( + "core.plugin.backwards_invocation.app.CompletionAppGenerator.generate", return_value={"ok": 1} + ) + app = MagicMock(mode=AppMode.COMPLETION) + + result = PluginAppBackwardsInvocation.invoke_completion_app(app, MagicMock(), False, {"x": 1}, []) + + assert result == {"ok": 1} + assert spy.call_count == 1 + + def test_get_user_returns_end_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [MagicMock(id="end-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "end-user" + + def test_get_user_falls_back_to_account_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, MagicMock(id="account-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "account-user" + + def test_get_user_raises_when_user_not_found(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, None] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + with pytest.raises(ValueError, match="user not found"): + PluginAppBackwardsInvocation._get_user("uid") + + def test_get_app_returns_app(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + app_obj = MagicMock(id="app") + query_chain.first.return_value = app_obj + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj + + def test_get_app_raises_when_missing(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + query_chain.first.return_value = None + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + PluginAppBackwardsInvocation._get_app("app", "tenant") + + def test_get_app_raises_when_query_fails(self, mocker): + db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + 
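# a low-level query failure is surfaced as the same generic "app not found" error +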
PluginAppBackwardsInvocation._get_app("app", "tenant") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py new file mode 100644 index 0000000000..b0b64a601b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -0,0 +1,347 @@ +import binascii +import datetime +from enum import StrEnum + +import pytest +from flask import Response +from pydantic import ValidationError + +from core.plugin.entities.endpoint import EndpointEntityWithInstance +from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import ( + RequestInvokeLLM, + RequestInvokeSpeech2Text, + TriggerDispatchResponse, + TriggerInvokeEventResponse, +) +from core.plugin.utils.http_parser import serialize_response +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) + + +class TestEndpointEntity: + def test_endpoint_entity_with_instance_renders_url(self, mocker): + mocker.patch("core.plugin.entities.endpoint.dify_config.ENDPOINT_URL_TEMPLATE", "https://dify.test/{hook_id}") + now = datetime.datetime.now(datetime.UTC) + + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + } + ) + + assert entity.url == "https://dify.test/hook-123" + + def test_endpoint_entity_with_instance_keeps_existing_url(self): + now = datetime.datetime.now(datetime.UTC) + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + "url": "https://preset.test/hook-123", + } + ) + assert entity.url == "https://preset.test/hook-123" + + +class TestMarketplaceEntities: + def test_marketplace_declaration_strips_empty_optional_fields(self): + declaration = MarketplacePluginDeclaration.model_validate( + { + "name": "plugin", + "org": "org", + "plugin_id": "org/plugin", + "icon": "icon.png", + "label": {"en_US": "Plugin"}, + "brief": {"en_US": "Brief"}, + "resource": {"memory": 256}, + "endpoint": {}, + "model": {}, + "tool": {}, + "latest_version": "1.0.0", + "latest_package_identifier": "org/plugin@1.0.0", + "status": "active", + "deprecated_reason": "", + "alternative_plugin_id": "", + } + ) + + assert declaration.endpoint is None + assert declaration.model is None + assert declaration.tool is None + + def test_marketplace_snapshot_computed_plugin_id(self): + snapshot = MarketplacePluginSnapshot( + org="langgenius", + name="search", + latest_version="1.0.0", + latest_package_identifier="langgenius/search@1.0.0", + latest_package_url="https://example.com/pkg", + ) + assert snapshot.plugin_id == "langgenius/search" + + +class TestPluginParameterEntities: + def _label(self) -> I18nObject: + return I18nObject(en_US="label") + + def 
test_parameter_option_value_casts_to_string(self): + option = PluginParameterOption(value=123, label=self._label()) + assert option.value == "123" + + def test_plugin_parameter_options_non_list_defaults_to_empty(self): + parameter = PluginParameter(name="p", label=self._label(), options="invalid") # type: ignore[arg-type] + assert parameter.options == [] + + @pytest.mark.parametrize( + ("parameter_type", "expected"), + [ + (PluginParameterType.SECRET_INPUT, "string"), + (PluginParameterType.SELECT, "string"), + (PluginParameterType.CHECKBOX, "string"), + (PluginParameterType.NUMBER, PluginParameterType.NUMBER.value), + ], + ) + def test_as_normal_type(self, parameter_type, expected): + assert as_normal_type(parameter_type) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [(None, ""), (1, "1"), ("abc", "abc")], + ) + def test_cast_parameter_value_string_like(self, value, expected): + assert cast_parameter_value(PluginParameterType.STRING, value) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + ("true", True), + ("yes", True), + ("1", True), + ("false", False), + ("0", False), + ("random", True), + (1, True), + (0, False), + ], + ) + def test_cast_parameter_value_boolean(self, value, expected): + assert cast_parameter_value(PluginParameterType.BOOLEAN, value) is expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (1, 1), + (1.5, 1.5), + ("2", 2), + ("2.5", 2.5), + ], + ) + def test_cast_parameter_value_number(self, value, expected): + assert cast_parameter_value(PluginParameterType.NUMBER, value) == expected + + def test_cast_parameter_value_file_and_files(self): + assert cast_parameter_value(PluginParameterType.FILES, "f1") == ["f1"] + assert cast_parameter_value(PluginParameterType.SYSTEM_FILES, ["f1", "f2"]) == ["f1", "f2"] + assert cast_parameter_value(PluginParameterType.FILE, ["one"]) == "one" + assert cast_parameter_value(PluginParameterType.FILE, "one") == "one" + with pytest.raises(ValueError, match="only accepts one file"): + cast_parameter_value(PluginParameterType.FILE, ["a", "b"]) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.MODEL_SELECTOR, {"m": "gpt"}, {"m": "gpt"}), + (PluginParameterType.APP_SELECTOR, {"app": "a"}, {"app": "a"}), + (PluginParameterType.TOOLS_SELECTOR, [], []), + (PluginParameterType.ANY, {"k": "v"}, {"k": "v"}), + ], + ) + def test_cast_parameter_value_selectors_valid(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "message"), + [ + (PluginParameterType.MODEL_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.APP_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.TOOLS_SELECTOR, "bad", "tools selector must be a list"), + (PluginParameterType.ANY, object(), "var selector must be"), + ], + ) + def test_cast_parameter_value_selectors_invalid(self, parameter_type, value, message): + with pytest.raises(ValueError, match=message): + cast_parameter_value(parameter_type, value) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, [1, 2], [1, 2]), + (PluginParameterType.ARRAY, "[1, 2]", [1, 2]), + (PluginParameterType.OBJECT, {"k": "v"}, {"k": "v"}), + (PluginParameterType.OBJECT, '{"a":1}', {"a": 1}), + ], + ) + def test_cast_parameter_value_array_and_object_valid(self, parameter_type, value, expected): + assert 
cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, "bad-json", ["bad-json"]), + (PluginParameterType.OBJECT, "bad-json", {}), + ], + ) + def test_cast_parameter_value_array_and_object_invalid_json_fallback(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + def test_cast_parameter_value_default_branch_and_wrapped_exception(self): + class _Unknown(StrEnum): + CUSTOM = "custom" + + assert cast_parameter_value(_Unknown.CUSTOM, 12) == "12" + + class _BadString: + def __str__(self): + raise RuntimeError("boom") + + with pytest.raises( + ValueError, + match=r"The tool parameter value <.*_BadString object at .* is not in correct type of string\.", + ): + cast_parameter_value(PluginParameterType.STRING, _BadString()) + + def test_init_frontend_parameter(self): + rule = PluginParameter( + name="choice", + label=self._label(), + required=True, + default="a", + options=[PluginParameterOption(value="a", label=self._label())], + ) + + assert init_frontend_parameter(rule, PluginParameterType.SELECT, None) == "a" + assert init_frontend_parameter(rule, PluginParameterType.NUMBER, 0) == 0 + with pytest.raises(ValueError, match="not in options"): + init_frontend_parameter(rule, PluginParameterType.SELECT, "b") + + required_rule = PluginParameter(name="required", label=self._label(), required=True, default=None) + with pytest.raises(ValueError, match="not found in tool config"): + init_frontend_parameter(required_rule, PluginParameterType.STRING, None) + + +class TestPluginDaemonEntities: + def test_credential_type_helpers(self): + assert CredentialType.API_KEY.get_name() == "API KEY" + assert CredentialType.OAUTH2.get_name() == "AUTH" + assert CredentialType.UNAUTHORIZED.get_name() == "UNAUTHORIZED" + + class _FakeCredential: + value = "custom-type" + + assert CredentialType.get_name(_FakeCredential()) == "CUSTOM TYPE" + assert CredentialType.API_KEY.is_editable() is True + assert CredentialType.OAUTH2.is_editable() is False + assert CredentialType.API_KEY.is_validate_allowed() is True + assert CredentialType.UNAUTHORIZED.is_validate_allowed() is False + assert set(CredentialType.values()) == {"api-key", "oauth2", "unauthorized"} + + @pytest.mark.parametrize( + ("raw", "expected"), + [ + ("api-key", CredentialType.API_KEY), + ("api_key", CredentialType.API_KEY), + ("oauth2", CredentialType.OAUTH2), + ("oauth", CredentialType.OAUTH2), + ("unauthorized", CredentialType.UNAUTHORIZED), + ], + ) + def test_credential_type_of(self, raw, expected): + assert CredentialType.of(raw) == expected + + def test_credential_type_of_invalid(self): + with pytest.raises(ValueError, match="Invalid credential type"): + CredentialType.of("invalid") + + +class TestPluginRequestEntities: + def test_request_invoke_llm_converts_prompt_messages(self): + payload = RequestInvokeLLM( + provider="openai", + model="gpt-4", + mode="chat", + prompt_messages=[ + {"role": "user", "content": "u"}, + {"role": "assistant", "content": "a"}, + {"role": "system", "content": "s"}, + {"role": "tool", "content": "t", "tool_call_id": "call-1"}, + ], + ) + + assert isinstance(payload.prompt_messages[0], UserPromptMessage) + assert isinstance(payload.prompt_messages[1], AssistantPromptMessage) + assert isinstance(payload.prompt_messages[2], SystemPromptMessage) + assert isinstance(payload.prompt_messages[3], ToolPromptMessage) + + def 
test_request_invoke_llm_prompt_messages_must_be_list(self): + with pytest.raises(ValidationError): + RequestInvokeLLM(provider="openai", model="gpt-4", mode="chat", prompt_messages="invalid") # type: ignore[arg-type] + + def test_request_invoke_speech2text_hex_conversion_and_error(self): + payload = RequestInvokeSpeech2Text(provider="openai", model="m", file=binascii.hexlify(b"abc").decode()) + assert payload.file == b"abc" + with pytest.raises(ValidationError): + RequestInvokeSpeech2Text(provider="openai", model="m", file=b"abc") # type: ignore[arg-type] + + def test_trigger_invoke_event_response_variables_conversion(self): + converted = TriggerInvokeEventResponse(variables='{"a": 1}', cancelled=False) + assert converted.variables == {"a": 1} + passthrough = TriggerInvokeEventResponse(variables={"b": 2}, cancelled=True) + assert passthrough.variables == {"b": 2} + + def test_trigger_dispatch_response_convert_response(self): + response = Response("ok", status=202, headers={"X-Req": "1"}) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.response.status_code == 202 + assert parsed.response.get_data() == b"ok" + with pytest.raises(ValidationError): + TriggerDispatchResponse(user_id="u", events=["e"], response="not-hex") + + def test_trigger_dispatch_response_payload_default(self): + response = Response("ok", status=200) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.payload == {} diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index e0eace0f2d..c7e94aa4cf 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -4,7 +4,10 @@ import pytest from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.plugin.utils.converter import convert_parameters_to_plugin_format +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File class TestChunkMerger: @@ -458,3 +461,89 @@ class TestChunkMerger: assert len(result) == 1 assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage) assert result[0].message.blob == b"FirstSecondThird" + + +class TestConverter: + def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): + file_param = File( + tenant_id="tenant-1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/file.png", + storage_key="", + ) + selector = ToolSelector( + provider_id="org/plugin/provider", + credential_id=None, + tool_name="search", + tool_description="search tool", + tool_configuration={"k": "v"}, + tool_parameters={ + "query": ToolSelector.Parameter( + name="query", + type=ToolParameter.ToolParameterType.STRING, + required=True, + description="query", + default="python", + options=[], + ) + }, + ) + params = {"file": file_param, "selector": selector, "plain": 123} + + converted = convert_parameters_to_plugin_format(params) + + assert converted["file"]["url"] == "https://example.com/file.png" + assert 
converted["selector"]["provider_id"] == "org/plugin/provider" + assert converted["plain"] == 123 + + def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): + file_one = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/a.txt", + storage_key="", + ) + file_two = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/b.txt", + storage_key="", + ) + selector_one = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-1", + tool_name="t1", + tool_description="tool 1", + tool_configuration={}, + tool_parameters={}, + ) + selector_two = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-2", + tool_name="t2", + tool_description="tool 2", + tool_configuration={}, + tool_parameters={}, + ) + + params = { + "files": [file_one, file_two], + "selectors": [selector_one, selector_two], + "empty_list": [], + "mixed_list": [file_one, "raw"], + "none_value": None, + } + + converted = convert_parameters_to_plugin_format(params) + + assert [item["url"] for item in converted["files"]] == [ + "https://example.com/a.txt", + "https://example.com/b.txt", + ] + assert [item["tool_name"] for item in converted["selectors"]] == ["t1", "t2"] + assert converted["empty_list"] == [] + assert converted["mixed_list"] == [file_one, "raw"] + assert converted["none_value"] is None diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py index 1c2e0c96f8..71144695bc 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -381,6 +381,54 @@ class TestEdgeCases: assert response.status_code == 200 assert response.get_data() == binary_body + def test_deserialize_request_with_lf_only_newlines(self): + raw_data = b"POST /lf-only?x=1 HTTP/1.1\nHost: localhost\nX-Test: yes\n\npayload" + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/lf-only" + assert request.args.get("x") == "1" + assert request.headers.get("X-Test") == "yes" + assert request.get_data() == b"payload" + + def test_deserialize_request_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"GET /no-separator HTTP/1.1\nHost: localhost\nInvalidHeader\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/no-separator" + assert request.headers.get("Host") == "localhost" + assert request.headers.get("InvalidHeader") is None + + def test_deserialize_request_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP request"): + deserialize_request(b"") + + def test_deserialize_response_with_lf_only_newlines(self): + raw_data = b"HTTP/1.1 202 Accepted\nX-Test: yes\n\nbody" + + response = deserialize_response(raw_data) + + assert response.status_code == 202 + assert response.headers.get("X-Test") == "yes" + assert response.get_data() == b"body" + + def test_deserialize_response_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"HTTP/1.1 204 No Content\nX-Test: yes\nInvalidHeader\n" + + response = deserialize_response(raw_data) + + assert response.status_code == 204 + assert response.headers.get("X-Test") == "yes" + assert response.headers.get("InvalidHeader") is None + assert response.get_data() == b"" + + def 
test_deserialize_response_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP response"): + deserialize_response(b"") + class TestFileUploads: def test_serialize_request_with_text_file_upload(self): diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3e184cbf21..3d08525aba 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -1,3 +1,4 @@ +from typing import cast from unittest.mock import MagicMock, patch import pytest @@ -13,6 +14,8 @@ from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, UserPromptMessage, ) from models.model import Conversation @@ -188,3 +191,328 @@ def get_chat_model_args(): context = "I am superman." return model_config_mock, memory_config, prompt_messages, inputs, context + + +def test_get_prompt_dispatches_completion_and_chat_and_invalid(): + transform = AdvancedPromptTransform() + model_config = MagicMock(spec=ModelConfigEntity) + completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic") + chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")] + + transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")]) + transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")]) + + completion_result = transform.get_prompt( + prompt_template=completion_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert completion_result[0].content == "c" + + chat_result = transform.get_prompt( + prompt_template=chat_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert chat_result[0].content == "h" + + invalid_result = transform.get_prompt( + prompt_template=cast(list, ["not-chat-model-message"]), + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert invalid_result == [] + + +def test_completion_prompt_jinja2_with_files(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") + + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with ( + patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"), + patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content, + ): + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_completion_model_prompt_messages( + prompt_template=completion_template, + inputs={"name": "John"}, + query="", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0].content, list) 
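+        # converted file content is expected first in the message; the rendered template text follows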
+ assert messages[0].content[0].data == "https://example.com/image.jpg" + assert isinstance(messages[0].content[1], TextPromptMessageContent) + assert messages[0].content[1].data == "Hi John" + + +def test_completion_prompt_basic_sets_query_variable(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + template = CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic") + + messages = transform._get_completion_model_prompt_messages( + prompt_template=template, + inputs={}, + query="what?", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert messages[0].content == "Q=what?" + + +def test_chat_prompt_with_variable_template_and_context(): + transform = AdvancedPromptTransform(with_variable_tmpl=True) + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"#node.name#": "john"}, + query=None, + files=[], + context="context-text", + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0], SystemPromptMessage) + assert messages[0].content == "sys=john ctx=context-text" + + +def test_chat_prompt_jinja2_branch_and_invalid_edition(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")] + + with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"): + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"name": "John"}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert messages[0].content == "Hello John" + + bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")] + with pytest.raises(ValueError, match="Invalid edition type"): + transform._get_chat_model_prompt_messages( + prompt_template=bad_prompt_template, + inputs={}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + +def test_chat_prompt_query_template_and_query_only_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + query_prompt_template="query={{#sys.query#}} ctx={{#context#}}", + ) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="what", + files=[], + context="ctx", + memory_config=memory_config, + memory=None, + model_config=model_config_mock, + ) + assert messages[-1].content == "query={{#sys.query#}} ctx=ctx" + + +def test_chat_prompt_memory_with_files_and_query(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) + memory = MagicMock(spec=TokenBufferMemory) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + file = File( + id="file1", + 
tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + transform._append_chat_histories = MagicMock( + side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages + ) + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="q", + files=[file], + context=None, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "q" + + +def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_with_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "u" + + prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_without_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1], UserPromptMessage) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "" + + +def test_chat_prompt_files_with_query_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=[], + inputs={}, + query="query-text", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "query-text" + + +def 
test_set_context_query_histories_variable_helpers(): + transform = AdvancedPromptTransform() + parser_context = PromptTemplateParser(template="{{#context#}}") + parser_query = PromptTemplateParser(template="{{#query#}}") + parser_hist = PromptTemplateParser(template="{{#histories#}}") + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ) + + assert transform._set_context_variable(None, parser_context, {})["#context#"] == "" + assert transform._set_query_variable("", parser_query, {})["#query#"] == "" + assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x" + assert ( + transform._set_histories_variable( + memory=None, # type: ignore[arg-type] + memory_config=memory_config, + raw_prompt="{{#histories#}}", + role_prefix=memory_config.role_prefix, # type: ignore[arg-type] + parser=parser_hist, + prompt_inputs={}, + model_config=model_config_mock, + )["#histories#"] + == "" + ) diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py index e3e500e310..1b114b369a 100644 --- a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -2,12 +2,14 @@ from uuid import uuid4 from constants import UUID_NIL from core.prompt.utils.extract_thread_messages import extract_thread_messages +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length class MockMessage: - def __init__(self, id, parent_message_id): + def __init__(self, id, parent_message_id, answer="answer"): self.id = id self.parent_message_id = parent_message_id + self.answer = answer def __getitem__(self, item): return getattr(self, item) @@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages(): result = extract_thread_messages(messages) assert len(result) == 4 assert [msg["id"] for msg in result] == [id5, id4, id2, id1] + + +def test_extract_thread_messages_breaks_when_parent_is_none(): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [MockMessage(id2, None), MockMessage(id1, UUID_NIL)] + + result = extract_thread_messages(messages) + + assert len(result) == 1 + assert result[0].id == id2 + + +def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer=""), # newest generated message should be excluded + MockMessage(id1, UUID_NIL, answer="ok"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-1") + + assert length == 1 + mock_scalars.assert_called_once() + + +def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer="latest-answer"), + MockMessage(id1, UUID_NIL, answer="older-answer"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-2") + + assert length == 2 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 
4136816562..9fc300348a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,11 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) @@ -25,3 +30,82 @@ def test_dump_prompt_message(): ) data = prompt.model_dump() assert data["content"][0].get("url") == example_url + + +def test_prompt_messages_to_prompt_for_saving_chat_mode(): + chat_messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="hello "), + ImagePromptMessageContent( + url="https://example.com/image1.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + AudioPromptMessageContent( + url="https://example.com/audio1.mp3", + format="mp3", + mime_type="audio/mpeg", + ), + TextPromptMessageContent(data="world"), + ] + ), + AssistantPromptMessage( + content="assistant-text", + tool_calls=[ + { + "id": "tool-1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"python"}'}, + } + ], + ), + ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"), + UserPromptMessage.model_construct(role="unknown", content="skip"), # type: ignore[arg-type] + ] + + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages) + + assert len(prompts) == 3 + assert prompts[0]["role"] == "user" + assert prompts[0]["text"] == "hello world" + assert prompts[0]["files"][0]["type"] == "image" + assert prompts[0]["files"][1]["type"] == "audio" + + assert prompts[1]["role"] == "assistant" + assert prompts[1]["text"] == "assistant-text" + assert prompts[1]["tool_calls"][0]["function"]["name"] == "search" + assert prompts[2]["role"] == "tool" + + +def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files(): + completion_message_with_files = UserPromptMessage( + content=[ + TextPromptMessageContent(data="first "), + TextPromptMessageContent(data="second"), + ImagePromptMessageContent( + url="https://example.com/image2.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.LOW, + ), + ] + ) + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_with_files] + ) + assert prompts == [ + { + "role": "user", + "text": "first second", + "files": prompts[0]["files"], + } + ] + assert prompts[0]["files"][0]["type"] == "image" + + completion_message_text_only = UserPromptMessage(content="plain text") + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_text_only] + ) + assert prompts == [{"role": "user", "text": "plain text"}] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 7976120547..d379e3067a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,4 +1,10 @@ -# from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.entities.model_entities 
import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -9,44 +15,217 @@ # from core.prompt.prompt_transform import PromptTransform -# def test__calculate_rest_token(): -# model_schema_mock = MagicMock(spec=AIModelEntity) -# parameter_rule_mock = MagicMock(spec=ParameterRule) -# parameter_rule_mock.name = "max_tokens" -# model_schema_mock.parameter_rules = [parameter_rule_mock] -# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} +class TestPromptTransform: + def test_resolve_model_runtime_requires_model_config_or_instance(self): + transform = PromptTransform() -# large_language_model_mock = MagicMock(spec=LargeLanguageModel) -# large_language_model_mock.get_num_tokens.return_value = 6 + with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."): + transform._resolve_model_runtime() -# provider_mock = MagicMock(spec=ProviderEntity) -# provider_mock.provider = "openai" + def test_resolve_model_runtime_builds_model_instance_from_model_config(self): + transform = PromptTransform() + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = fake_model_schema + fake_model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials=None, + parameters=None, + stop=None, + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="config-model", + credentials={"api_key": "secret"}, + parameters={"temperature": 0.1}, + stop=["END"], + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + ) -# provider_configuration_mock = MagicMock(spec=ProviderConfiguration) -# provider_configuration_mock.provider = provider_mock -# provider_configuration_mock.model_settings = None + with patch( + "core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance + ) as model_instance_cls: + model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config) -# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) -# provider_model_bundle_mock.model_type_instance = large_language_model_mock -# provider_model_bundle_mock.configuration = provider_configuration_mock + model_instance_cls.assert_called_once_with( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model, + ) + fake_model_type_instance.get_model_schema.assert_called_once_with( + model="resolved-model", + credentials={"api_key": "secret"}, + ) + assert model_instance is fake_model_instance + assert model_instance.credentials == {"api_key": "secret"} + assert model_instance.parameters == {"temperature": 0.1} + assert model_instance.stop == ["END"] + assert model_schema is fake_model_schema -# model_config_mock = MagicMock(spec=ModelConfigEntity) -# model_config_mock.model = "gpt-4" -# model_config_mock.credentials = {} -# model_config_mock.parameters = {"max_tokens": 50} -# model_config_mock.model_schema = model_schema_mock -# model_config_mock.provider_model_bundle = provider_model_bundle_mock + def test_resolve_model_runtime_uses_model_config_schema_fallback(self): + transform = PromptTransform() + fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + 
model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + model_config = SimpleNamespace(model_schema=fallback_schema) -# prompt_transform = PromptTransform() + resolved_model_instance, resolved_schema = transform._resolve_model_runtime( + model_config=model_config, + model_instance=model_instance, + ) -# prompt_messages = [UserPromptMessage(content="Hello, how are you?")] -# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + assert resolved_model_instance is model_instance + assert resolved_schema is fallback_schema -# # Validate based on the mock configuration and expected logic -# expected_rest_tokens = ( -# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] -# - model_config_mock.parameters["max_tokens"] -# - large_language_model_mock.get_num_tokens.return_value -# ) -# assert rest_tokens == expected_rest_tokens -# assert rest_tokens == 6 + def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self): + transform = PromptTransform() + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + + with pytest.raises(ValueError, match="Model schema not found for the provided model instance."): + transform._resolve_model_runtime(model_instance=model_instance) + + def test_calculate_rest_token_defaults_when_context_size_missing(self): + transform = PromptTransform() + fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0) + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + provider_model_bundle=object(), + model="test-model", + parameters={}, + ) + + rest = transform._calculate_rest_token([], model_config=model_config) + + assert rest == 2000 + + def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="max_tokens", use_template=None) + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 50}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 0 + + def test_calculate_rest_token_supports_use_template_parameter(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens") + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + 
parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 30}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 150 + + def test_get_history_messages_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_text.return_value = "history" + + memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3)) + result = transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_with_window, + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert result == "history" + memory.get_history_prompt_text.assert_called_with( + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + message_limit=3, + ) + + memory.reset_mock() + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2)) + transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_no_window, + max_token_limit=50, + ) + memory.get_history_prompt_text.assert_called_with(max_token_limit=50) + + def test_get_history_messages_list_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_messages.return_value = ["m1", "m2"] + + memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120) + assert result == ["m1", "m2"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2) + + memory.reset_mock() + memory.get_history_prompt_messages.return_value = ["only"] + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10) + assert result == ["only"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None) + + def test_append_chat_histories_extends_prompt_messages(self, monkeypatch): + transform = PromptTransform() + memory = MagicMock() + memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None)) + + monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99) + monkeypatch.setattr( + transform, + "_get_history_messages_list_from_memory", + lambda memory, memory_config, max_token_limit: ["h1", "h2"], + ) + + result = transform._append_chat_histories( + memory=memory, + memory_config=memory_config, + prompt_messages=["p1"], + model_config=SimpleNamespace(), + ) + + assert result == ["p1", "h1", "h2"] diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 2ef66e8a96..e6d28224d7 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,9 +1,29 @@ -from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest from 
core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation @@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages(): assert len(prompt_messages) == 1 assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt + + +def test_get_prompt_dispatches_chat_and_completion(): + transform = SimplePromptTransform() + model_config_chat = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_chat.mode = "chat" + model_config_completion = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_completion.mode = "completion" + prompt_entity = SimpleNamespace(simple_prompt_template="hello") + + transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None)) + transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"])) + + chat_messages, chat_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_chat, + ) + assert chat_messages == ["chat-msg"] + assert chat_stops is None + + completion_messages, completion_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_completion, + ) + assert completion_messages == ["completion-msg"] + assert completion_stops == ["stop"] + + +def test_get_prompt_str_and_rules_type_validation_errors(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config.provider = "openai" + model_config.model = "gpt-4" + valid_prompt_template = SimplePromptTransform().get_prompt_template( + AppMode.CHAT, "openai", "gpt-4", "", False, False + )["prompt_template"] + + bad_custom_keys = { + "prompt_template": valid_prompt_template, + "custom_variable_keys": "not-list", + "special_variable_keys": [], + "prompt_rules": {}, + } + transform.get_prompt_template = MagicMock(return_value=bad_custom_keys) + with pytest.raises(TypeError, match="custom_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_special_keys = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": "not-list", + } + transform.get_prompt_template = MagicMock(return_value=bad_special_keys) + with pytest.raises(TypeError, match="special_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + 
bad_prompt_template = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": 123, + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_template) + with pytest.raises(TypeError, match="PromptTemplateParser"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_prompt_rules = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": valid_prompt_template, + "prompt_rules": "not-dict", + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_rules) + with pytest.raises(TypeError, match="prompt_rules"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + +def test_chat_model_prompt_messages_uses_prompt_when_query_empty(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {})) + transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text")) + + prompt_messages, _ = transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert prompt_messages[0].content == "prompt-text" + transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None) + + +def test_completion_model_prompt_messages_empty_stops_becomes_none(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []})) + + prompt_messages, stops = transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert len(prompt_messages) == 1 + assert stops is None + + +def test_get_last_user_message_with_files_and_context_files(): + transform = SimplePromptTransform() + file = SimpleNamespace() + context_file = SimpleNamespace() + + with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.side_effect = [ + ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"), + ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"), + ] + message = transform._get_last_user_message( + prompt="hello", + files=[file], + context_files=[context_file], + image_detail_config=None, + ) + + assert isinstance(message.content, list) + assert message.content[0].data == "https://example.com/a.jpg" + assert message.content[1].data == "https://example.com/b.jpg" + assert isinstance(message.content[2], TextPromptMessageContent) + assert message.content[2].data == "hello" + + +def test_prompt_file_name_branches(): + transform = SimplePromptTransform() + + assert transform._prompt_file_name(AppMode.CHAT, "openai", "gpt-4") == "common_chat" + assert transform._prompt_file_name(AppMode.COMPLETION, "openai", "gpt-4") == "common_completion" + assert transform._prompt_file_name(AppMode.COMPLETION, "baichuan", "Baichuan2") == "baichuan_completion" + assert transform._prompt_file_name(AppMode.CHAT, "huggingface_hub", "baichuan-13b") == "baichuan_chat" + + +def test_advanced_prompt_templates_constants_are_importable(): + assert 
isinstance(CONTEXT, str) + assert isinstance(BAICHUAN_CONTEXT, str) + assert "completion_prompt_config" in CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in COMPLETION_APP_COMPLETION_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b90c4935af..de3ccc4518 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -3730,7 +3730,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=None, dataset_ids=["d1"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_not_called() @@ -3740,7 +3740,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=["f1"], dataset_ids=["d1", "d2"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_called() diff --git a/api/tests/unit_tests/core/tools/__init__.py b/api/tests/unit_tests/core/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py new file mode 100644 index 0000000000..f123f60a34 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + +class _BuiltinDummyTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +def _build_tool() -> _BuiltinDummyTool: + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) + + +def test_builtin_tool_fork_and_provider_type(): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, _BuiltinDummyTool) + assert forked.runtime.tenant_id == "tenant-2" + assert tool.tool_provider_type() == ToolProviderType.BUILT_IN + + +def test_invoke_model_calls_model_invocation_utils_invoke(): + tool = _build_tool() + with 
patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: + assert ( + tool.invoke_model( + user_id="u1", + prompt_messages=[UserPromptMessage(content="hello")], + stop=[], + ) + == "result" + ) + mock_invoke.assert_called_once() + + +def test_get_max_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + assert tool.get_max_tokens() == 4096 + + +def test_get_prompt_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + +def test_runtime_none_raises(): + tool = _build_tool() + tool.runtime = None + with pytest.raises(ValueError, match="runtime is required"): + tool.get_max_tokens() + with pytest.raises(ValueError, match="runtime is required"): + tool.get_prompt_tokens([UserPromptMessage(content="hello")]) + + +def test_builtin_tool_summary_short_and_long_content_paths(): + tool = _build_tool() + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=100): + with patch.object(_BuiltinDummyTool, "get_prompt_tokens", return_value=10): + assert tool.summary(user_id="u1", content="short") == "short" + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=10): + with patch.object( + _BuiltinDummyTool, + "get_prompt_tokens", + side_effect=lambda prompt_messages: len(prompt_messages[-1].content), + ): + with patch.object( + _BuiltinDummyTool, + "invoke_model", + return_value=SimpleNamespace(message=SimpleNamespace(content="S")), + ): + result = tool.summary(user_id="u1", content="x" * 30 + "\n" + "y" * 5) + + assert result + assert "S" in result diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py new file mode 100644 index 0000000000..ad6d5906ae --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderEntity, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError + + +class _FakeBuiltinTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _ConcreteBuiltinProvider(BuiltinToolProviderController): + last_validation: tuple[str, dict[str, Any]] | None = None + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + self.last_validation = (user_id, credentials) + + +def _provider_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": ["utilities"], + }, + 
"credentials_for_provider": { + "api_key": { + "type": "secret-input", + "required": True, + } + }, + "oauth_schema": { + "client_schema": [ + { + "name": "client_id", + "type": "text-input", + } + ], + "credentials_schema": [ + { + "name": "access_token", + "type": "secret-input", + } + ], + }, + } + + +def _tool_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "tool_a", + "label": {"en_US": "Tool A"}, + }, + "parameters": [], + } + + +def test_builtin_tool_provider_init_load_tools_and_basic_accessors(monkeypatch): + yaml_payloads = [_provider_yaml(), _tool_yaml()] + + def _load_yaml(*args, **kwargs): + return yaml_payloads.pop(0) + + monkeypatch.setattr("core.tools.builtin_tool.provider.load_yaml_file_cached", _load_yaml) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.listdir", + lambda *args, **kwargs: ["tool_a.yaml", "__init__.py", "readme.md"], + ) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + lambda *args, **kwargs: _FakeBuiltinTool, + ) + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema() + assert provider.get_tools() + assert provider.get_tool("tool_a") is not None + assert provider.get_tool("missing") is None + assert provider.provider_type == ToolProviderType.BUILT_IN + assert provider.tool_labels == ["utilities"] + assert provider.need_credentials is True + + oauth_schema = provider.get_credentials_schema_by_type(CredentialType.OAUTH2) + assert len(oauth_schema) == 1 + api_schema = provider.get_credentials_schema_by_type(CredentialType.API_KEY) + assert len(api_schema) == 1 + assert provider.get_oauth_client_schema()[0].name == "client_id" + assert set(provider.get_supported_credential_types()) == {CredentialType.API_KEY, CredentialType.OAUTH2} + + +def test_builtin_tool_provider_invalid_credential_type_raises(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + with pytest.raises(ValueError, match="Invalid credential type: invalid"): + provider.get_credentials_schema_by_type("invalid") + + +def test_builtin_tool_provider_validate_credentials_delegates(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + provider.validate_credentials("user-1", {"api_key": "secret"}) + assert provider.last_validation == ("user-1", {"api_key": "secret"}) + + +def test_builtin_tool_provider_unauthorized_schema_is_empty(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema_by_type(CredentialType.UNAUTHORIZED) == [] + + +def 
test_builtin_tool_provider_init_raises_when_provider_yaml_missing(): + with patch("core.tools.builtin_tool.provider.load_yaml_file_cached", side_effect=RuntimeError("boom")): + with pytest.raises(ToolProviderNotFoundError, match="can not load provider yaml"): + _ConcreteBuiltinProvider() + + +def test_builtin_tool_provider_handles_empty_credentials_and_oauth(): + provider = object.__new__(_ConcreteBuiltinProvider) + provider.tools = [] + provider.entity = ToolProviderEntity.model_validate( + { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": None, + }, + "credentials_schema": [], + "oauth_schema": None, + }, + ) + + assert provider.get_oauth_client_schema() == [] + assert provider.get_supported_credential_types() == [] + assert provider.need_credentials is False + assert provider._get_tool_labels() == [] + + +def test_builtin_tool_provider_forked_tool_runtime_is_initialized(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + tool = provider.get_tool("tool_a") + assert tool is not None + assert isinstance(tool.runtime, ToolRuntime) + assert tool.runtime.tenant_id == "" + tool.runtime.invoke_from = InvokeFrom.DEBUGGER + assert tool.runtime.invoke_from == InvokeFrom.DEBUGGER diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py new file mode 100644 index 0000000000..62cfb6ce5b --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import math +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort +from core.tools.builtin_tool.providers.audio.audio import AudioToolProvider +from core.tools.builtin_tool.providers.audio.tools.asr import ASRTool +from core.tools.builtin_tool.providers.audio.tools.tts import TTSTool +from core.tools.builtin_tool.providers.code.code import CodeToolProvider +from core.tools.builtin_tool.providers.code.tools.simple_code import SimpleCode +from core.tools.builtin_tool.providers.time.time import WikiPediaProvider +from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool +from core.tools.builtin_tool.providers.time.tools.localtime_to_timestamp import LocaltimeToTimestampTool +from core.tools.builtin_tool.providers.time.tools.timestamp_to_localtime import TimestampToLocaltimeTool +from core.tools.builtin_tool.providers.time.tools.timezone_conversion import TimezoneConversionTool +from core.tools.builtin_tool.providers.time.tools.weekday import WeekdayTool +from core.tools.builtin_tool.providers.webscraper.tools.webscraper import WebscraperTool +from core.tools.builtin_tool.providers.webscraper.webscraper import WebscraperProvider +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage 
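+# ToolInvokeError is the failure type asserted for the time, code, and webscraper tools in the tests below.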
+from core.tools.errors import ToolInvokeError +from dify_graph.file.enums import FileType +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey + + +def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return tool_cls(provider="provider-a", entity=entity, runtime=runtime) + + +def _raise_runtime_error(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("boom") + + +def test_current_time_tool(): + current_tool = _build_builtin_tool(CurrentTimeTool) + utc_text = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "UTC"}))[0].message.text + assert utc_text + + invalid_tz = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "Invalid/TZ"}))[0].message.text + assert "Invalid timezone" in invalid_tz + + +def test_localtime_to_timestamp_tool(): + localtime_tool = _build_builtin_tool(LocaltimeToTimestampTool) + ts_message = list( + localtime_tool.invoke(user_id="u", tool_parameters={"localtime": "2024-01-01 10:00:00", "timezone": "UTC"}) + )[0].message.text + ts_value = float(ts_message.strip()) + assert math.isfinite(ts_value) + assert ts_value >= 0 + with pytest.raises(ToolInvokeError): + LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC") + + +def test_timestamp_to_localtime_tool(): + to_local_tool = _build_builtin_tool(TimestampToLocaltimeTool) + local_text = list(to_local_tool.invoke(user_id="u", tool_parameters={"timestamp": 1704067200, "timezone": "UTC"}))[ + 0 + ].message.text + assert "2024" in local_text + with pytest.raises(ToolInvokeError): + TimestampToLocaltimeTool.timestamp_to_localtime("bad", "UTC") # type: ignore[arg-type] + + +def test_timezone_conversion_tool(): + timezone_tool = _build_builtin_tool(TimezoneConversionTool) + converted = list( + timezone_tool.invoke( + user_id="u", + tool_parameters={ + "current_time": "2024-01-01 08:00:00", + "current_timezone": "UTC", + "target_timezone": "Asia/Tokyo", + }, + ) + )[0].message.text + assert converted.startswith("2024-01-01") + with pytest.raises(ToolInvokeError): + TimezoneConversionTool.timezone_convert("bad", "UTC", "Asia/Tokyo") + + +def test_weekday_tool(): + weekday_tool = _build_builtin_tool(WeekdayTool) + valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text + assert "January 1, 2024" in valid + invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ + 0 + ].message.text + assert "Invalid date" in invalid + with pytest.raises(ValueError, match="Month is required"): + list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "day": 1})) + + +def test_simple_code_valid_execution(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + lambda *a: "ok", + ) + result = list( + simple_code.invoke( + user_id="u", + tool_parameters={"language": "python3", "code": "print(1)"}, + ) + )[0].message.text + assert result == "ok" + + +def test_simple_code_invalid_language(): + simple_code = _build_builtin_tool(SimpleCode) + + with pytest.raises(ValueError, match="Only python3 and javascript"): + list(simple_code.invoke(user_id="u", 
tool_parameters={"language": "go", "code": "fmt.Println(1)"})) + + +def test_simple_code_execution_error(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(simple_code.invoke(user_id="u", tool_parameters={"language": "python3", "code": "print(1)"})) + + +def test_webscraper_empty_url(): + webscraper = _build_builtin_tool(WebscraperTool) + empty = list(webscraper.invoke(user_id="u", tool_parameters={"url": ""}))[0].message.text + assert empty == "Please input url" + + +def test_webscraper_fetch(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + full = list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))[0].message.text + assert full == "page" + + +def test_webscraper_summary(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + monkeypatch.setattr(webscraper, "summary", lambda user_id, content: "summary") + summarized = list( + webscraper.invoke( + user_id="u", + tool_parameters={"url": "https://example.com", "generate_summary": True}, + ) + )[0].message.text + assert summarized == "summary" + + +def test_webscraper_fetch_error(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr( + "core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"})) + + +def test_asr_invalid_file(): + asr = _build_builtin_tool(ASRTool) + file_obj = SimpleNamespace(type=FileType.DOCUMENT) + invalid_file = list(asr.invoke(user_id="u", tool_parameters={"audio_file": file_obj}))[0].message.text + assert "not a valid audio file" in invalid_file + + +def test_asr_valid_file_invocation(monkeypatch): + asr = _build_builtin_tool(ASRTool) + model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + audio_file = SimpleNamespace(type=FileType.AUDIO) + ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text + assert ok == "transcript" + + +def test_asr_available_models_and_runtime_parameters(monkeypatch): + asr = _build_builtin_tool(ASRTool) + provider_model = type("PM", (), {"provider": "p", "models": [type("Model", (), {"model": "m"})()]})() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelProviderService.get_models_by_model_type", + lambda *a, **k: [provider_model], + ) + assert asr.get_available_models() == [("p", "m")] + assert asr.get_runtime_parameters()[0].name == "model" + + +def test_tts_invoke_returns_messages(monkeypatch): + tts = _build_builtin_tool(TTSTool) + voices_model_instance = type( + "TTSM", + (), + { + "get_tts_voices": lambda self: [{"value": "voice-1"}], + 
"invoke_tts": lambda self, **kwargs: [b"a", b"b"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + ) + messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + + +def test_tts_get_available_models_requires_runtime(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + tts.get_available_models() + + +def test_tts_tool_raises_when_runtime_missing(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +@pytest.mark.parametrize( + "voices", + [[{"value": None}], []], +) +def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): + tts = _build_builtin_tool(TTSTool) + tts.runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + model_without_voice = type( + "TTSModelNoVoice", + (), + { + "get_tts_voices": lambda self: voices, + "invoke_tts": lambda self, **kwargs: [b"x"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + ) + with pytest.raises(ValueError, match="no voice available"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +def test_tts_tool_get_available_models_and_runtime_parameters(monkeypatch): + tts = _build_builtin_tool(TTSTool) + + model_1 = SimpleNamespace( + model="model-a", + model_properties={ModelPropertyKey.VOICES: [{"mode": "v1", "name": "Voice 1"}]}, + ) + model_2 = SimpleNamespace(model="model-b", model_properties={}) + provider_models = [SimpleNamespace(provider="provider-a", models=[model_1, model_2])] + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelProviderService.get_models_by_model_type", + lambda *args, **kwargs: provider_models, + ) + + available_models = tts.get_available_models() + assert available_models == [ + ("provider-a", "model-a", [{"mode": "v1", "name": "Voice 1"}]), + ("provider-a", "model-b", []), + ] + + runtime_parameters = tts.get_runtime_parameters() + assert runtime_parameters[0].name == "model" + assert runtime_parameters[0].required is True + assert runtime_parameters[0].options[0].value == "provider-a#model-a" + assert runtime_parameters[1].name == "voice#provider-a#model-a" + + +def test_provider_classes_and_builtin_sort(monkeypatch): + # Use object.__new__ to avoid YAML-loading __init__; only pass-through validation is exercised. + # Ensure pass-through _validate_credentials methods are executed. 
+ AudioToolProvider._validate_credentials(object.__new__(AudioToolProvider), "u", {}) + CodeToolProvider._validate_credentials(object.__new__(CodeToolProvider), "u", {}) + WikiPediaProvider._validate_credentials(object.__new__(WikiPediaProvider), "u", {}) + WebscraperProvider._validate_credentials(object.__new__(WebscraperProvider), "u", {}) + + providers = [SimpleNamespace(name="b"), SimpleNamespace(name="a")] + monkeypatch.setattr(BuiltinToolProviderSort, "_position", {}) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.get_tool_position_map", + lambda _: {"a": 0, "b": 1}, + ) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.sort_by_position_map", + lambda position, values, name_func: sorted(values, key=lambda x: name_func(x)), + ) + sorted_providers = BuiltinToolProviderSort.sort(providers) + assert [p.name for p in sorted_providers] == ["a", "b"] diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py new file mode 100644 index 0000000000..79b8eaaa87 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import httpx +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.custom_tool.tool import ApiTool, ParsedResponse +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError + + +def _build_tool(*, openapi: dict | None = None) -> ApiTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + bundle = ApiToolBundle( + server_url="https://api.example.com/items/{id}", + method="GET", + summary="summary", + operation_id="op-id", + parameters=[], + author="author", + openapi=openapi or {"parameters": []}, + ) + runtime = ToolRuntime( + tenant_id="tenant-1", + invoke_from=InvokeFrom.DEBUGGER, + credentials={"auth_type": "api_key_header", "api_key_value": "k"}, + ) + return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="provider-id") + + +def test_parsed_response_to_string(): + assert ParsedResponse({"a": 1}, True).to_string() == '{"a": 1}' + assert ParsedResponse("ok", False).to_string() == "ok" + + +def test_api_tool_fork_runtime_and_validate_credentials(monkeypatch): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, ApiTool) + assert forked.runtime.tenant_id == "tenant-2" + + tool.api_bundle = None # type: ignore[assignment] + with pytest.raises(ValueError, match="api_bundle is required"): + tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + + tool = _build_tool() + assert tool.validate_credentials(credentials={}, parameters={}, format_only=True) == "" + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {"Authorization": "Bearer x"}) + monkeypatch.setattr( + tool, + "do_http_request", + lambda url, method, headers, parameters: httpx.Response(200, json={"ok": True}), + ) + result = tool.validate_credentials(credentials={}, parameters={"a": 1}, format_only=False) + assert result == 
'{"ok": true}' + + +def test_assembling_request_auth_header_assembly(): + tool = _build_tool() + + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "k" + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "bearer", + "api_key_value": "abc", + } + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Bearer abc" + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_header_prefix": "basic", "api_key_value": "abc"} + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Basic abc" + + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_value": "abc"} + assert tool.assembling_request(parameters={}) == {} + + +def test_assembling_request_runtime_auth_errors(): + tool = _build_tool() + + tool.runtime = None + with pytest.raises(ToolProviderCredentialValidationError, match="runtime not initialized"): + tool.assembling_request(parameters={}) + + tool.runtime = ToolRuntime(tenant_id="tenant", credentials={}) + with pytest.raises(ToolProviderCredentialValidationError, match="Missing auth_type"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header"} + with pytest.raises(ToolProviderCredentialValidationError, match="Missing api_key_value"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": 123} + with pytest.raises(ToolProviderCredentialValidationError, match="must be a string"): + tool.assembling_request(parameters={}) + + +def test_assembling_request_parameter_validation_and_defaults(): + tool = _build_tool() + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "x"} + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default=None), + ] + with pytest.raises(ToolParameterValidationError, match="Missing required parameter required_param"): + tool.assembling_request(parameters={}) + + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default="d"), + ] + params = {} + tool.assembling_request(parameters=params) + assert params["required_param"] == "d" + + +def test_validate_and_parse_response_branches(): + tool = _build_tool() + + with pytest.raises(ToolInvokeError, match="status code 500"): + tool.validate_and_parse_response(httpx.Response(500, text="boom")) + + empty = tool.validate_and_parse_response(httpx.Response(200, content=b"")) + assert empty.is_json is False + assert "Empty response from the tool" in str(empty.content) + + json_resp = tool.validate_and_parse_response( + httpx.Response(200, json={"a": 1}, headers={"content-type": "application/json"}) + ) + assert json_resp.is_json is True + assert json_resp.content == {"a": 1} + + non_json_type = tool.validate_and_parse_response( + httpx.Response(200, text='{"a": 1}', headers={"content-type": "text/plain"}) + ) + assert non_json_type.is_json is False + assert non_json_type.content == '{"a": 1}' + + plain_resp = tool.validate_and_parse_response(httpx.Response(200, text="plain")) + assert plain_resp.is_json is False + assert plain_resp.content == "plain" + + with pytest.raises(ValueError, match="Invalid response type"): + tool.validate_and_parse_response("invalid") # type: ignore[arg-type] + + +def test_get_parameter_value_and_type_conversion_helpers(): + tool = _build_tool() + + assert tool.get_parameter_value({"name": "x"}, {"x": 1}) == 1 + assert 
tool.get_parameter_value({"name": "x", "required": False, "schema": {"default": "d"}}, {}) == "d" + with pytest.raises(ToolParameterValidationError, match="Missing required parameter x"): + tool.get_parameter_value({"name": "x", "required": True}, {}) + + assert tool._convert_body_property_any_of({}, "12", [{"type": "integer"}]) == 12 + assert tool._convert_body_property_any_of({}, "1.5", [{"type": "number"}]) == 1.5 + assert tool._convert_body_property_any_of({}, "true", [{"type": "boolean"}]) is True + assert tool._convert_body_property_any_of({}, "", [{"type": "null"}]) is None + assert tool._convert_body_property_any_of({}, "x", [{"anyOf": [{"type": "string"}]}]) == "x" + + assert tool._convert_body_property_type({"type": "integer"}, "1") == 1 + assert tool._convert_body_property_type({"type": "number"}, "1.2") == 1.2 + assert tool._convert_body_property_type({"type": "string"}, 1) == "1" + assert tool._convert_body_property_type({"type": "boolean"}, 1) is True + assert tool._convert_body_property_type({"type": "null"}, None) is None + assert tool._convert_body_property_type({"type": "object"}, '{"a":1}') == {"a": 1} + assert tool._convert_body_property_type({"type": "array"}, "[1,2]") == [1, 2] + assert tool._convert_body_property_type({"type": "invalid"}, "v") == "v" + assert tool._convert_body_property_type({"anyOf": [{"type": "integer"}]}, "2") == 2 + + +def test_do_http_request_builds_arguments_and_handles_invalid_method(monkeypatch): + openapi = { + "parameters": [ + {"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}, + {"name": "q", "in": "query", "required": False, "schema": {"default": ""}}, + {"name": "X-Extra", "in": "header", "required": False, "schema": {"default": "x"}}, + {"name": "sid", "in": "cookie", "required": False, "schema": {"default": "cookie1"}}, + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["count"], + "properties": { + "count": {"type": "integer"}, + "name": {"type": "string", "default": "n"}, + }, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "v"} + headers = {} + captured = {} + + def _fake_get(url, **kwargs): + captured["url"] = url + captured["kwargs"] = kwargs + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.get", _fake_get) + response = tool.do_http_request( + "https://api.example.com/items/{id}", + "GET", + headers=headers, + parameters={"id": "123", "count": "2", "q": "search"}, + ) + + assert isinstance(response, httpx.Response) + assert captured["url"].endswith("/items/123") + assert captured["kwargs"]["params"]["q"] == "search" + assert captured["kwargs"]["params"]["key"] == "v" + assert captured["kwargs"]["headers"]["Content-Type"] == "application/json" + + invalid_method_tool = _build_tool(openapi={"parameters": []}) + with pytest.raises(ValueError, match="Invalid http method"): + invalid_method_tool.do_http_request("https://api.example.com", "TRACE", headers={}, parameters={}) + + +def test_do_http_request_handles_file_upload_and_invoke_paths(monkeypatch): + openapi = { + "parameters": [], + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": {"file": {"format": "binary"}}, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "k"} + fake_file = 
SimpleNamespace(filename="a.txt", mime_type="text/plain") + captured = {} + + def _fake_post(url, **kwargs): + captured["headers"] = kwargs["headers"] + captured["files"] = kwargs["files"] + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.download", lambda _: b"file-bytes") + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.post", _fake_post) + response = tool.do_http_request( + "https://api.example.com/upload", + "POST", + headers={}, + parameters={"file": fake_file}, + ) + assert isinstance(response, httpx.Response) + assert "Content-Type" not in captured["headers"] + assert captured["files"][0][0] == "file" + + # _invoke JSON path + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {}) + monkeypatch.setattr(tool, "do_http_request", lambda *args, **kwargs: httpx.Response(200, text='{"a":1}')) + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse({"a": 1}, True)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.TEXT] + + # _invoke text path + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse("plain", False)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert len(messages) == 1 + assert messages[0].message.text == "plain" diff --git a/api/tests/unit_tests/core/tools/test_custom_tool_provider.py b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py new file mode 100644 index 0000000000..93ae217e24 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolProviderType + + +def _db_provider() -> SimpleNamespace: + bundle = ApiToolBundle( + server_url="https://api.example.com/items", + method="GET", + summary="List items", + operation_id="list_items", + parameters=[], + author="author", + openapi={"parameters": []}, + ) + return SimpleNamespace( + id="provider-id", + tenant_id="tenant-1", + name="provider-a", + description="desc", + icon="icon.svg", + user=SimpleNamespace(name="Alice"), + tools=[bundle], + ) + + +def test_api_tool_provider_from_db_and_parse_tool_bundle(): + controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_HEADER) + assert controller.provider_type == ToolProviderType.API + assert any(c.name == "api_key_value" for c in controller.entity.credentials_schema) + + tool = controller._parse_tool_bundle(_db_provider().tools[0]) + assert isinstance(tool, ApiTool) + assert tool.entity.identity.provider == "provider-id" + + +def test_api_tool_provider_from_db_query_auth_and_none_auth(): + query_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_QUERY) + assert any(c.name == "api_key_query_param" for c in query_controller.entity.credentials_schema) + + none_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + assert [c.name for c in none_controller.entity.credentials_schema] == ["auth_type"] + + +def test_api_tool_provider_load_get_tools_and_get_tool(): + controller = 
ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + loaded = controller.load_bundled_tools(_db_provider().tools) + assert len(loaded) == 1 + + assert isinstance(controller.get_tool("list_items"), ApiTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + # Return cached tools without querying database. + cached = controller.get_tools("tenant-1") + assert len(cached) == 1 + + # Force DB fetch branch. + controller.tools = [] + provider_with_tools = _db_provider() + with patch("core.tools.custom_tool.provider.db") as mock_db: + scalars_result = Mock() + scalars_result.all.return_value = [provider_with_tools] + mock_db.session.scalars.return_value = scalars_result + tools = controller.get_tools("tenant-1") + assert len(tools) == 1 diff --git a/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py new file mode 100644 index 0000000000..23c0be9487 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py @@ -0,0 +1,145 @@ +"""Unit tests for DatasetRetrieverTool behavior and retrieval wiring.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE) + + +def test_get_dataset_tools_returns_empty_for_empty_dataset_ids() -> None: + # Arrange + retrieve_config = _retrieve_config() + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=[], + retrieve_config=retrieve_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_returns_empty_for_missing_retrieve_config() -> None: + # Arrange + dataset_ids = ["d1"] + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=dataset_ids, + retrieve_config=None, # type: ignore[arg-type] + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_builds_tool_and_restores_strategy() -> None: + # Arrange + retrieve_config = _retrieve_config() + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + + # Act + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=retrieve_config, + return_resource=True, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={"x": 1}, + ) + + # Assert + assert len(tools) == 1 + assert tools[0].entity.identity.name == "dataset_tool" + assert retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + + +def _build_dataset_tool() -> tuple[DatasetRetrieverTool, SimpleNamespace]: + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda 
query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=_retrieve_config(), + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + return tools[0], retrieval_tool + + +def test_runtime_parameters_shape() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + params = tool.get_runtime_parameters() + + # Assert + assert len(params) == 1 + assert params[0].name == "query" + + +def test_empty_query_behavior() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + empty_query = list(tool.invoke(user_id="u", tool_parameters={})) + + # Assert + assert len(empty_query) == 1 + assert empty_query[0].message.text == "please input query" + + +def test_query_invocation_result() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = list(tool.invoke(user_id="u", tool_parameters={"query": "hello"})) + + # Assert + assert len(result) == 1 + assert result[0].message.text == "result:hello" + + +def test_validate_credentials() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = tool.validate_credentials(credentials={}, parameters={}, format_only=False) + + # Assert + assert result is None diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool.py b/api/tests/unit_tests/core/tools/test_mcp_tool.py new file mode 100644 index 0000000000..eaf054de59 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import base64 +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.mcp.types import ( + BlobResourceContents, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + TextResourceContents, +) +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_tool(*, with_output_schema: bool = True) -> MCPTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="remote-tool", + label=I18nObject(en_US="remote-tool"), + provider="provider-id", + ), + parameters=[], + output_schema={"type": "object"} if with_output_schema else {}, + ) + return MCPTool( + entity=entity, + runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER), + tenant_id="tenant-1", + icon="icon.svg", + server_url="https://mcp.example.com", + provider_id="provider-id", + headers={"x-auth": "token"}, + ) + + +def test_mcp_tool_provider_type_and_fork_runtime(): + tool = _build_mcp_tool() + assert tool.tool_provider_type() == ToolProviderType.MCP + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, MCPTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.provider_id == "provider-id" + + +def test_mcp_tool_text_and_json_processing_helpers(): + tool = _build_mcp_tool() + + json_messages = list(tool._process_text_content(TextContent(type="text", text='{"a": 1}'))) + assert json_messages[0].type == ToolInvokeMessage.MessageType.JSON + + 
plain_messages = list(tool._process_text_content(TextContent(type="text", text="not-json"))) + assert plain_messages[0].type == ToolInvokeMessage.MessageType.TEXT + assert plain_messages[0].message.text == "not-json" + + list_messages = list(tool._process_json_content([{"k": 1}, {"k": 2}])) + assert [m.type for m in list_messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.JSON] + + mixed_list_messages = list(tool._process_json_list([{"k": 1}, 2])) + assert len(mixed_list_messages) == 1 + assert mixed_list_messages[0].type == ToolInvokeMessage.MessageType.TEXT + + primitive_messages = list(tool._process_json_content(123)) + assert primitive_messages[0].message.text == "123" + + +def test_mcp_tool_usage_extraction_helpers(): + usage = MCPTool._extract_usage_dict({"usage": {"total_tokens": 9}}) + assert usage == {"total_tokens": 9} + + usage = MCPTool._extract_usage_dict({"metadata": {"usage": {"prompt_tokens": 3, "completion_tokens": 2}}}) + assert usage == {"prompt_tokens": 3, "completion_tokens": 2} + + usage = MCPTool._extract_usage_dict({"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}) + assert usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} + + usage = MCPTool._extract_usage_dict({"nested": [{"deep": {"usage": {"total_tokens": 7}}}]}) + assert usage == {"total_tokens": 7} + + result_with_usage = CallToolResult(content=[], _meta={"usage": {"prompt_tokens": 1, "completion_tokens": 2}}) + derived = MCPTool._derive_usage_from_result(result_with_usage) + assert derived.prompt_tokens == 1 + assert derived.completion_tokens == 2 + + result_without_usage = CallToolResult(content=[], _meta=None) + derived = MCPTool._derive_usage_from_result(result_without_usage) + assert derived.total_tokens == 0 + + +def test_mcp_tool_invoke_handles_content_types_and_structured_output(): + tool = _build_mcp_tool() + img_data = base64.b64encode(b"img").decode() + blob_data = base64.b64encode(b"blob").decode() + result = CallToolResult( + content=[ + TextContent(type="text", text='{"a": 1}'), + ImageContent(type="image", data=img_data, mimeType="image/png"), + EmbeddedResource( + type="resource", + resource=TextResourceContents(uri="file:///tmp/a.txt", text="embedded-text"), + ), + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="file:///tmp/b.bin", + blob=blob_data, + mimeType="application/octet-stream", + ), + ), + ], + structuredContent={"x": 1}, + _meta={"usage": {"prompt_tokens": 2, "completion_tokens": 3}}, + ) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"a": 1})) + + types = [m.type for m in messages] + assert ToolInvokeMessage.MessageType.JSON in types + assert ToolInvokeMessage.MessageType.BLOB in types + assert ToolInvokeMessage.MessageType.TEXT in types + assert ToolInvokeMessage.MessageType.VARIABLE in types + assert tool.latest_usage.total_tokens == 5 + + +def test_mcp_tool_invoke_raises_for_unsupported_embedded_resource(): + tool = _build_mcp_tool() + # Use model_construct to bypass pydantic validation and force unsupported resource path. 
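+    # A resource payload that is neither TextResourceContents nor BlobResourceContents must surface as ToolInvokeError.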
+ bad_resource = EmbeddedResource.model_construct(type="resource", resource=object()) + result = CallToolResult(content=[bad_resource], _meta=None) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + with pytest.raises(ToolInvokeError, match="Unsupported embedded resource type"): + list(tool.invoke(user_id="user-1", tool_parameters={})) + + +def test_mcp_tool_handle_none_parameter_filters_empty_values(): + tool = _build_mcp_tool() + cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"}) + assert cleaned == {"a": 1, "e": "ok"} diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py new file mode 100644 index 0000000000..1060d19ab1 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_entity(*, icon: str = "icon.svg") -> MCPProviderEntity: + now = datetime.now() + return MCPProviderEntity( + id="db-id", + provider_id="provider-id", + name="MCP Provider", + tenant_id="tenant-1", + user_id="user-1", + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[ + { + "name": "remote-tool", + "description": "remote tool", + "inputSchema": {}, + "outputSchema": {"type": "object"}, + } + ], + icon=icon, + created_at=now, + updated_at=now, + ) + + +def test_mcp_tool_provider_controller_from_entity_and_get_tools(): + entity = _build_mcp_entity() + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_entity(entity) + + assert controller.provider_type == ToolProviderType.MCP + tool = controller.get_tool("remote-tool") + assert isinstance(tool, MCPTool) + assert tool.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], MCPTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_mcp_tool_provider_controller_from_entity_requires_icon(): + entity = _build_mcp_entity(icon="") + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + with pytest.raises(ValueError, match="icon is required"): + MCPToolProviderController.from_entity(entity) + + +def test_mcp_tool_provider_controller_from_db_delegates_to_entity(): + entity = _build_mcp_entity() + db_provider = Mock() + db_provider.to_entity.return_value = entity + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_db(db_provider) + assert isinstance(controller, MCPToolProviderController) diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool.py b/api/tests/unit_tests/core/tools/test_plugin_tool.py new file mode 100644 index 0000000000..4378432a0f --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolParameter +from core.tools.plugin_tool.tool import PluginTool + + +def _build_plugin_tool(*, has_runtime_parameters: bool) -> PluginTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ], + has_runtime_parameters=has_runtime_parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, credentials={"api_key": "x"}) + return PluginTool( + entity=entity, + runtime=runtime, + tenant_id="tenant-1", + icon="icon.svg", + plugin_unique_identifier="plugin-uid", + ) + + +def test_plugin_tool_invoke_and_fork_runtime(): + tool = _build_plugin_tool(has_runtime_parameters=False) + manager = Mock() + manager.invoke.return_value = iter([tool.create_text_message("ok")]) + + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + with patch( + "core.tools.plugin_tool.tool.convert_parameters_to_plugin_format", + return_value={"converted": 1}, + ): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"raw": 1})) + + assert [m.message.text for m in messages] == ["ok"] + manager.invoke.assert_called_once() + assert manager.invoke.call_args.kwargs["tool_parameters"] == {"converted": 1} + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, PluginTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.plugin_unique_identifier == "plugin-uid" + + +def test_plugin_tool_get_runtime_parameters_branches(): + tool = _build_plugin_tool(has_runtime_parameters=False) + assert tool.get_runtime_parameters() == tool.entity.parameters + + tool = _build_plugin_tool(has_runtime_parameters=True) + cached = [ + ToolParameter.get_simple_instance( + name="k", + llm_description="k", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + tool.runtime_parameters = cached + assert tool.get_runtime_parameters() == cached + + tool.runtime_parameters = None + manager = Mock() + returned = [ + ToolParameter.get_simple_instance( + name="dyn", + llm_description="dyn", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + manager.get_runtime_parameters.return_value = returned + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + assert tool.get_runtime_parameters(conversation_id="c1", app_id="a1", message_id="m1") == returned + assert tool.runtime_parameters == returned diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py new file mode 100644 index 0000000000..5ef03cc6ca --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolProviderEntityWithPlugin, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError +from 
core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool + + +def _build_controller() -> PluginToolProviderController: + tool_entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + entity = ToolProviderEntityWithPlugin( + identity=ToolProviderIdentity( + author="author", + name="provider-a", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ), + credentials_schema=[], + plugin_id="plugin-id", + tools=[tool_entity], + ) + return PluginToolProviderController( + entity=entity, + plugin_id="plugin-id", + plugin_unique_identifier="plugin-uid", + tenant_id="tenant-1", + ) + + +def test_plugin_tool_provider_controller_basic_behaviors(): + controller = _build_controller() + assert controller.provider_type == ToolProviderType.PLUGIN + + tool = controller.get_tool("tool-a") + assert isinstance(tool, PluginTool) + assert tool.runtime.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], PluginTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_validate_credentials_success(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = True + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) + + manager.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant-1", + user_id="u1", + provider="provider-a", + credentials={"api_key": "x"}, + ) + + +def test_validate_credentials_failure(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = False + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py new file mode 100644 index 0000000000..a5242a78c5 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -0,0 +1,119 @@ +"""Unit tests for core.tools.signature covering signing and verification invariants.""" + +from __future__ import annotations + +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature + + +def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=False) + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + 
sign = query["sign"][0] + + assert parsed.scheme == "https" + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is True + + +def test_sign_tool_file_for_external_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x04" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=True) + parsed = urlparse(url) + + assert parsed.scheme == "https" + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + + +def test_verify_tool_file_signature_rejects_invalid_sign(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, "bad-signature") is False + + +def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + monkeypatch.setattr("core.tools.signature.time.time", lambda: int(timestamp) + 99) + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False + + +def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + + url = 
sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] + + +def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + + url = sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py new file mode 100644 index 0000000000..40c107667c --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ( + ToolEngineInvokeError, + ToolInvokeError, + ToolParameterValidationError, +) +from core.tools.tool_engine import ToolEngine + + +class _DummyTool(Tool): + result: Any + raise_error: Exception | None + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime): + super().__init__(entity=entity, runtime=runtime) + self.result = [self.create_text_message("ok")] + self.raise_error = None + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + if self.raise_error: + raise self.raise_error + if isinstance(self.result, list | Generator): + yield from self.result + else: + yield self.result + + +def _build_tool(with_llm_parameter: bool = False) -> _DummyTool: + parameters = [] + if with_llm_parameter: + parameters = [ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, runtime_parameters={"rt": 1}) + return _DummyTool(entity=entity, runtime=runtime) + + +def test_convert_tool_response_to_str_and_extract_binary_messages(): + tool = _build_tool() + 
messages = [ + tool.create_text_message("hello"), + tool.create_link_message("https://example.com"), + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="https://example.com/a.png"), + meta={"mime_type": "image/png"}, + ), + tool.create_json_message({"a": 1}), + tool.create_json_message({"a": 1}, suppress_output=True), + ] + text = ToolEngine._convert_tool_response_to_str(messages) + assert "hello" in text + assert "result link: https://example.com." in text + assert '"a": 1' in text + + blob_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.TextMessage(text="https://example.com/blob.bin"), + meta={"mime_type": "application/octet-stream"}, + ) + link_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://example.com/file.pdf"), + meta={"mime_type": "application/pdf"}, + ) + binaries = list(ToolEngine._extract_tool_response_binary_and_text([messages[2], blob_message, link_message])) + assert [b.mimetype for b in binaries] == ["image/png", "application/octet-stream", "application/pdf"] + + with pytest.raises(ValueError, match="missing meta data"): + list( + ToolEngine._extract_tool_response_binary_and_text( + [ + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="x"), + ) + ] + ) + ) + + +def test_create_message_files_and_invoke_generator(): + binaries = [ + ToolInvokeMessageBinary(mimetype="image/png", url="https://example.com/abc.png"), + ToolInvokeMessageBinary(mimetype="audio/wav", url="https://example.com/def.wav"), + ] + created = [] + + def _message_file_factory(**kwargs): + obj = SimpleNamespace(id=f"mf-{len(created) + 1}", **kwargs) + created.append(obj) + return obj + + with patch("core.tools.tool_engine.MessageFile", side_effect=_message_file_factory): + with patch("core.tools.tool_engine.db") as mock_db: + ids = ToolEngine._create_message_files( + tool_messages=binaries, + agent_message=SimpleNamespace(id="msg-1"), + invoke_from=InvokeFrom.DEBUGGER, + user_id="user-1", + ) + + assert ids == ["mf-1", "mf-2"] + assert mock_db.session.add.call_count == 2 + mock_db.session.close.assert_called_once() + + tool = _build_tool() + invoked = list(ToolEngine._invoke(tool, {"a": 1}, user_id="u")) + assert invoked[0].type == ToolInvokeMessage.MessageType.TEXT + assert isinstance(invoked[-1], ToolInvokeMeta) + assert invoked[-1].error is None + + +def test_generic_invoke_success_and_error_paths(): + tool = _build_tool() + callback = Mock() + callback.on_tool_execution.side_effect = lambda **kwargs: kwargs["tool_outputs"] + response = list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=callback, + workflow_call_depth=0, + conversation_id="c1", + app_id="a1", + message_id="m1", + ) + ) + assert response[0].message.text == "ok" + callback.on_tool_start.assert_called_once() + callback.on_tool_execution.assert_called_once() + + tool.raise_error = RuntimeError("boom") + error_callback = Mock() + error_callback.on_tool_execution.side_effect = lambda **kwargs: list(kwargs["tool_outputs"]) + with pytest.raises(RuntimeError, match="boom"): + list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=error_callback, + workflow_call_depth=0, + ) + ) + error_callback.on_tool_error.assert_called_once() + + +def test_agent_invoke_success(): + tool = 
_build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + meta = ToolInvokeMeta.empty() + + with patch.object(ToolEngine, "_invoke", return_value=iter([tool.create_text_message("ok"), meta])): + with patch( + "core.tools.tool_engine.ToolFileMessageTransformer.transform_tool_invoke_messages", + side_effect=lambda messages, **kwargs: messages, + ): + with patch.object(ToolEngine, "_extract_tool_response_binary_and_text", return_value=iter([])): + with patch.object(ToolEngine, "_create_message_files", return_value=[]): + result_text, message_files, result_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters="hello", + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert result_text == "ok" + assert message_files == [] + assert result_meta.error is None + callback.on_tool_start.assert_called_once() + callback.on_tool_end.assert_called_once() + + +def test_agent_invoke_param_validation_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolParameterValidationError("bad-param")): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool parameters validation error" in error_text + assert files == [] + assert error_meta.error + + +def test_agent_invoke_engine_meta_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + engine_error = ToolEngineInvokeError(ToolInvokeMeta.error_instance("meta failure")) + + with patch.object(ToolEngine, "_invoke", side_effect=engine_error): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "meta failure" in error_text + assert files == [] + assert error_meta.error == "meta failure" + + +def test_agent_invoke_tool_invoke_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolInvokeError("invoke boom")): + error_text, files, _ = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool invoke error" in error_text + assert files == [] diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py new file mode 100644 index 0000000000..cca8254dd6 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -0,0 +1,249 @@ +"""Unit tests for `ToolFileManager` behavior. + +Covers signing/verification, file persistence flows, and retrieval APIs with +mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to +avoid real IO. 
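+
+As a minimal sketch of the shared pattern (all names are helpers and APIs
+used in this module), a canned session is injected like so:
+
+    session = Mock()
+    session.query.return_value.where.return_value.first.return_value = None
+    with _patch_session_factory(session):
+        assert ToolFileManager().get_file_binary("missing") is None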
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest + +from core.tools.tool_file_manager import ToolFileManager + + +def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100) + + url = ToolFileManager.sign_file("tf-1", ".png") + return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&")) + + +def _patch_session_factory(session: Mock): + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm) + + +def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + url = ToolFileManager.sign_file("tf-1", ".png") + assert "/files/tools/tf-1.png" in url + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True + + +def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False + + +def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0) + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False + + +def test_create_file_by_raw_stores_file_and_persists_record() -> None: + manager = ToolFileManager() + session = Mock() + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-1") + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.guess_extension", return_value=".txt"), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="abc")), + _patch_session_factory(session), + ): + file_model = manager.create_file_by_raw( + user_id="u1", + tenant_id="t1", + conversation_id="c1", + file_binary=b"hello", + mimetype="text/plain", + filename="readme", + ) + + assert file_model.name.endswith(".txt") + storage.save.assert_called_once() + session.add.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_downloads_and_persists_record() -> None: + manager = ToolFileManager() + response = Mock() + response.content = b"binary" + response.headers = {"Content-Type": "application/octet-stream"} + 
response.raise_for_status.return_value = None + session = Mock() + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-2") + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")), + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response), + ): + file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + assert file_model.file_key.startswith("tools/t1/") + storage.save.assert_called_once() + session.add.assert_called_once_with(file_model) + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_raises_on_timeout() -> None: + manager = ToolFileManager() + + with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")): + with pytest.raises(ValueError, match="timeout when downloading file"): + manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + +def test_get_file_binary_returns_none_when_not_found() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary("missing") + + # Assert + assert result is None + + +def test_get_file_binary_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"hello" + with _patch_session_factory(session): + result = manager.get_file_binary("id1") + + # Assert + assert result == (b"hello", "text/plain") + + +def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = None + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url=None) + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + first_query = Mock() + second_query = Mock() + 
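# The first query resolves the MessageFile row, the second the referenced ToolFile. +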
first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = tool_file + session.query.side_effect = [first_query, second_query] + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"img" + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result == (b"img", "image/png") + + +def test_get_file_generator_returns_none_when_toolfile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + stream, tool_file = manager.get_file_generator_by_tool_file_id("tool123") + + # Assert + assert stream is None + assert tool_file is None + + +def test_get_file_generator_returns_stream_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + stream = iter([b"a", b"b"]) + storage.load_stream.return_value = stream + with ( + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), + ): + result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") + assert list(result_stream) == [b"a", b"b"] + assert result_file == "validated-file" diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py new file mode 100644 index 0000000000..857f4aa178 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import PropertyMock, patch + +import pytest + +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + +class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + return None + + +def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController: + controller = object.__new__(ApiToolProviderController) + controller.provider_id = provider_id + return controller + + +def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderController: + controller = object.__new__(WorkflowToolProviderController) + controller.provider_id = provider_id + return controller + + +def test_tool_label_manager_filter_tool_labels(): + filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"]) + assert set(filtered) == {"search", "news"} + assert len(filtered) == 2 + + +def test_tool_label_manager_update_tool_labels_db(): + controller = _api_controller("api-1") + with patch("core.tools.tool_label_manager.db") as mock_db: + delete_query = mock_db.session.query.return_value.where.return_value + delete_query.delete.return_value = None + ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) + + delete_query.delete.assert_called_once() 
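+        # `update_tool_labels` should clear the old bindings with a single
+        # bulk delete before re-inserting the filtered labels;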
+        # only one valid unique label should be inserted.
+        assert mock_db.session.add.call_count == 1
+        mock_db.session.commit.assert_called_once()
+
+
+def test_tool_label_manager_update_tool_labels_unsupported():
+    with pytest.raises(ValueError, match="Unsupported tool type"):
+        ToolLabelManager.update_tool_labels(object(), ["search"])  # type: ignore[arg-type]
+
+
+def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
+    with patch.object(
+        _ConcreteBuiltinToolProviderController,
+        "tool_labels",
+        new_callable=PropertyMock,
+        return_value=["search", "news"],
+    ):
+        builtin = object.__new__(_ConcreteBuiltinToolProviderController)
+        assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]
+
+    api = _api_controller("api-1")
+    with patch("core.tools.tool_label_manager.db") as mock_db:
+        mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
+        labels = ToolLabelManager.get_tool_labels(api)
+        assert labels == ["search", "news"]
+
+    with pytest.raises(ValueError, match="Unsupported tool type"):
+        ToolLabelManager.get_tool_labels(object())  # type: ignore[arg-type]
+
+
+def test_tool_label_manager_get_tools_labels_batch():
+    assert ToolLabelManager.get_tools_labels([]) == {}
+
+    api = _api_controller("api-1")
+    wf = _workflow_controller("wf-1")
+    records = [
+        SimpleNamespace(tool_id="api-1", label_name="search"),
+        SimpleNamespace(tool_id="api-1", label_name="news"),
+        SimpleNamespace(tool_id="wf-1", label_name="utilities"),
+    ]
+    with patch("core.tools.tool_label_manager.db") as mock_db:
+        mock_db.session.scalars.return_value.all.return_value = records
+        labels = ToolLabelManager.get_tools_labels([api, wf])
+        assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}
+
+    with pytest.raises(ValueError, match="Unsupported tool type"):
+        ToolLabelManager.get_tools_labels([api, object()])  # type: ignore[list-item]
diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py
new file mode 100644
index 0000000000..0f73e22654
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/test_tool_manager.py
@@ -0,0 +1,899 @@
+"""Unit tests for ToolManager behavior with mocked providers and collaborators."""
+
+from __future__ import annotations
+
+import json
+import threading
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.plugin.entities.plugin_daemon import CredentialType
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.tool_entities import (
+    ApiProviderAuthType,
+    ToolParameter,
+    ToolProviderType,
+)
+from core.tools.errors import ToolProviderNotFoundError
+from core.tools.plugin_tool.provider import PluginToolProviderController
+from core.tools.tool_manager import ToolManager
+
+
+class _SimpleContextVar:
+    def __init__(self):
+        self._is_set = False
+        self._value: Any = None
+
+    def get(self):
+        if not self._is_set:
+            raise LookupError
+        return self._value
+
+    def set(self, value: Any):
+        self._value = value
+        self._is_set = True
+
+
+def _cm(session: Any):
+    context = Mock()
+    context.__enter__ = Mock(return_value=session)
+    context.__exit__ = Mock(return_value=False)
+    return context
+
+
+def _setup_list_providers_from_api_mocks(
+    monkeypatch,
+    *,
+    session: Mock,
+    hardcoded_controller: SimpleNamespace,
+    plugin_controller: PluginToolProviderController,
+    api_controller: SimpleNamespace,
+    workflow_controller:
SimpleNamespace, +): + mock_db = Mock() + mock_db.engine = object() + monkeypatch.setattr("core.tools.tool_manager.db", mock_db) + monkeypatch.setattr("core.tools.tool_manager.Session", lambda *args, **kwargs: _cm(session)) + monkeypatch.setattr( + ToolManager, + "list_builtin_providers", + Mock(return_value=[hardcoded_controller, plugin_controller]), + ) + monkeypatch.setattr( + ToolManager, + "list_default_builtin_providers", + Mock(return_value=[SimpleNamespace(provider="hardcoded")]), + ) + monkeypatch.setattr("core.tools.tool_manager.is_filtered", lambda *args, **kwargs: False) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.builtin_provider_to_user_provider", + lambda **kwargs: SimpleNamespace(name=kwargs["provider_controller"].entity.identity.name), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_controller", + Mock(side_effect=[api_controller, RuntimeError("invalid")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="api-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + Mock(side_effect=[workflow_controller, RuntimeError("deleted app")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="workflow-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolLabelManager.get_tools_labels", + Mock(side_effect=[{"api-1": ["search"]}, {"wf-1": ["utility"]}]), + ) + mock_mcp_service = Mock() + mock_mcp_service.list_providers.return_value = [SimpleNamespace(name="mcp-provider")] + monkeypatch.setattr("core.tools.tool_manager.MCPToolManageService", Mock(return_value=mock_mcp_service)) + monkeypatch.setattr("core.tools.tool_manager.BuiltinToolProviderSort.sort", lambda providers: providers) + + +@pytest.fixture(autouse=True) +def _reset_tool_manager_state(): + old_hardcoded = ToolManager._hardcoded_providers.copy() + old_loaded = ToolManager._builtin_providers_loaded + old_labels = ToolManager._builtin_tools_labels.copy() + try: + yield + finally: + ToolManager._hardcoded_providers = old_hardcoded + ToolManager._builtin_providers_loaded = old_loaded + ToolManager._builtin_tools_labels = old_labels + + +def test_get_hardcoded_provider_loads_cache_when_empty(): + provider = Mock() + ToolManager._hardcoded_providers = {} + + def _load(): + ToolManager._hardcoded_providers["weather"] = provider + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load) as mock_load: + assert ToolManager.get_hardcoded_provider("weather") is provider + + mock_load.assert_called_once() + + +def test_get_builtin_provider_returns_plugin_for_missing_hardcoded(): + hardcoded = Mock() + plugin_provider = Mock() + ToolManager._hardcoded_providers = {"time": hardcoded} + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + assert ToolManager.get_builtin_provider("time", "tenant-1") is hardcoded + assert ToolManager.get_builtin_provider("plugin/time", "tenant-1") is plugin_provider + + +def test_get_plugin_provider_uses_context_cache(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + provider_entity = SimpleNamespace(declaration=Mock(), plugin_id="pid", plugin_unique_identifier="uid") + + with 
patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = provider_entity + controller = SimpleNamespace(name="controller") + with patch("core.tools.tool_manager.PluginToolProviderController", return_value=controller): + built = ToolManager.get_plugin_provider("provider-a", "tenant-1") + cached = ToolManager.get_plugin_provider("provider-a", "tenant-1") + + assert built is controller + assert cached is controller + mock_manager_cls.return_value.fetch_tool_provider.assert_called_once() + + +def test_get_plugin_provider_raises_when_provider_missing(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + + with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="plugin provider provider-a not found"): + ToolManager.get_plugin_provider("provider-a", "tenant-1") + + +def test_get_tool_runtime_builtin_without_credentials(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace(get_tool=Mock(return_value=tool), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="current_time", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.tenant_id == "tenant-1" + assert runtime.credentials == {} + + +def test_get_tool_runtime_builtin_missing_tool_raises(): + controller = SimpleNamespace(get_tool=Mock(return_value=None), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with pytest.raises(ToolProviderNotFoundError, match="builtin tool missing not found"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="missing", + tenant_id="tenant-1", + ) + + +def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.API_KEY.value, + credentials={"encrypted": "value"}, + expires_at=-1, + user_id="user-1", + ) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with patch("core.helper.credential_utils.check_credential_policy_compliance"): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + builtin_provider + ) + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key": "secret"} + cache = Mock() + with patch("core.tools.tool_manager.create_provider_encrypter", return_value=(encrypter, 
cache)): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.credentials == {"api_key": "secret"} + assert runtime.credential_type == CredentialType.API_KEY + + +@patch("core.tools.tool_manager.create_provider_encrypter") +@patch("core.plugin.impl.oauth.OAuthHandler") +@patch( + "services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client", + return_value={"client_id": "id"}, +) +@patch("core.tools.tool_manager.db") +@patch("core.tools.tool_manager.time.time", return_value=1000) +@patch("core.helper.credential_utils.check_credential_policy_compliance") +def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( + mock_check, + mock_time, + mock_db, + mock_get_oauth_client, + mock_oauth_handler_cls, + mock_create_provider_encrypter, +): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"encrypted": "value"}, + encrypted_credentials=None, + expires_at=1, + user_id="user-1", + ) + refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) + + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + encrypter = Mock() + encrypter.decrypt.return_value = {"token": "old"} + encrypter.encrypt.return_value = {"token": "encrypted"} + cache = Mock() + mock_create_provider_encrypter.return_value = (encrypter, cache) + mock_oauth_handler_cls.return_value.refresh_credentials.return_value = refreshed + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + assert builtin_provider.expires_at == refreshed.expires_at + assert builtin_provider.encrypted_credentials == json.dumps({"token": "encrypted"}) + mock_db.session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_get_tool_runtime_builtin_plugin_provider_deleted_raises(): + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(credentials_schema=[{"name": "k"}], oauth_schema=None) + plugin_controller.get_tool = Mock(return_value=Mock()) + plugin_controller.get_credentials_schema_by_type = Mock(return_value=[]) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_controller): + with patch("core.tools.tool_manager.is_valid_uuid", return_value=True): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.scalar.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="provider has been deleted"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + credential_id="uuid-id", + ) + + +def test_get_tool_runtime_api_path(): + api_tool = Mock() + api_tool.fork_tool_runtime.return_value = "api-runtime" + api_provider = Mock() + api_provider.get_tool.return_value = api_tool + + with 
patch.object(ToolManager, "get_api_provider_controller", return_value=(api_provider, {"c": "enc"})): + encrypter = Mock() + encrypter.decrypt.return_value = {"c": "dec"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + ) + == "api-runtime" + ) + + +def test_get_tool_runtime_workflow_path(): + workflow_provider = SimpleNamespace(tenant_id="tenant-1") + workflow_tool = Mock() + workflow_tool.fork_tool_runtime.return_value = "wf-runtime" + workflow_controller = Mock() + workflow_controller.get_tools.return_value = [workflow_tool] + session = Mock() + session.begin.return_value = _cm(None) + session.scalar.return_value = workflow_provider + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + return_value=workflow_controller, + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.WORKFLOW, + provider_id="wf-1", + tool_name="wf", + tenant_id="tenant-1", + ) + == "wf-runtime" + ) + + +def test_get_tool_runtime_plugin_path(): + with patch.object( + ToolManager, + "get_plugin_provider", + return_value=SimpleNamespace(get_tool=lambda _: "plugin-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.PLUGIN, + provider_id="plugin-1", + tool_name="p", + tenant_id="tenant-1", + ) + == "plugin-tool" + ) + + +def test_get_tool_runtime_mcp_path(): + with patch.object( + ToolManager, + "get_mcp_provider_controller", + return_value=SimpleNamespace(get_tool=lambda _: "mcp-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.MCP, + provider_id="mcp-1", + tool_name="m", + tenant_id="tenant-1", + ) + == "mcp-tool" + ) + + +def test_get_tool_runtime_app_not_implemented(): + with pytest.raises(NotImplementedError, match="app provider not implemented"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.APP, + provider_id="app", + tool_name="x", + tenant_id="tenant-1", + ) + + +def test_get_agent_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={"query": "hello"}, + credential_id=None, + ) + result = ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert result is tool_runtime + assert 
tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + + +def test_get_workflow_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + workflow_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_configurations={"query": "hello"}, + credential_id=None, + ) + tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + workflow_result = ToolManager.get_workflow_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + workflow_tool=workflow_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert workflow_result is tool_runtime2 + assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + + +def test_get_agent_runtime_raises_when_runtime_missing(): + tool_runtime = SimpleNamespace(runtime=None, get_merged_runtime_parameters=lambda: []) + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={}, + credential_id=None, + ) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={}): + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=Mock()): + with pytest.raises(ValueError, match="runtime not found"): + ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + ) + + +def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): + form_param = ToolParameter.get_simple_instance( + name="q", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + form_param.form = ToolParameter.ToolParameterForm.FORM + llm_param = ToolParameter.get_simple_instance( + name="llm", + llm_description="llm", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + llm_param.form = ToolParameter.ToolParameterForm.LLM + + tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + result = ToolManager.get_tool_runtime_from_plugin( + tool_type=ToolProviderType.API, + tenant_id="tenant-1", + provider="api-1", + tool_name="search", + tool_parameters={"q": "hello", "llm": "ignore"}, + ) + + assert result is tool_entity + assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + + +def test_hardcoded_provider_icon_success(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=True): + with 
patch("core.tools.tool_manager.mimetypes.guess_type", return_value=("image/svg+xml", None)): + icon_path, mime = ToolManager.get_hardcoded_provider_icon("time") + assert icon_path.endswith("icon.svg") + assert mime == "image/svg+xml" + + +def test_hardcoded_provider_icon_missing_raises(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=False): + with pytest.raises(ToolProviderNotFoundError, match="icon not found"): + ToolManager.get_hardcoded_provider_icon("time") + + +def test_list_hardcoded_providers_cache_hit(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + assert list(ToolManager.list_hardcoded_providers()) == list(ToolManager._hardcoded_providers.values()) + + +def test_clear_hardcoded_providers_cache_resets(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + ToolManager.clear_hardcoded_providers_cache() + assert ToolManager._hardcoded_providers == {} + assert ToolManager._builtin_providers_loaded is False + + +def test_list_hardcoded_providers_internal_loader(): + good_provider = SimpleNamespace( + entity=SimpleNamespace(identity=SimpleNamespace(name="good")), + get_tools=lambda: [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="tool-a", label="A")))], + ) + provider_class = Mock(return_value=good_provider) + + with patch("core.tools.tool_manager.listdir", return_value=["good", "bad", "__skip"]): + with patch("core.tools.tool_manager.path.isdir", side_effect=lambda p: "good" in p or "bad" in p): + with patch( + "core.tools.tool_manager.load_single_subclass_from_source", + side_effect=[provider_class, RuntimeError("boom")], + ): + ToolManager._hardcoded_providers = {} + ToolManager._builtin_tools_labels = {} + providers = list(ToolManager._list_hardcoded_providers()) + + assert providers == [good_provider] + assert ToolManager._hardcoded_providers["good"] is good_provider + assert ToolManager._builtin_tools_labels["tool-a"] == "A" + assert ToolManager._builtin_providers_loaded is True + + +def test_get_tool_label_loads_cache_and_handles_missing(): + ToolManager._builtin_tools_labels = {} + + def _load(): + ToolManager._builtin_tools_labels["tool-a"] = "Label A" + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load): + assert ToolManager.get_tool_label("tool-a") == "Label A" + assert ToolManager.get_tool_label("missing") is None + + +def test_list_default_builtin_providers_for_postgres_and_mysql(): + provider_records = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + + for scheme in ("postgresql", "mysql"): + session = Mock() + session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + session.query.return_value.where.return_value.all.return_value = provider_records + + with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + providers = ToolManager.list_default_builtin_providers("tenant-1") + + assert providers == provider_records + + +def test_list_providers_from_api_covers_builtin_api_workflow_and_mcp(monkeypatch): + hardcoded_controller = 
SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="hardcoded"))) + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(identity=SimpleNamespace(name="plugin-provider")) + + api_db_provider_good = SimpleNamespace(id="api-1") + api_db_provider_bad = SimpleNamespace(id="api-2") + api_controller = SimpleNamespace(provider_id="api-1") + + workflow_db_provider_good = SimpleNamespace(id="wf-1") + workflow_db_provider_bad = SimpleNamespace(id="wf-2") + workflow_controller = SimpleNamespace(provider_id="wf-1") + + session = Mock() + session.scalars.side_effect = [ + SimpleNamespace(all=lambda: [api_db_provider_good, api_db_provider_bad]), + SimpleNamespace(all=lambda: [workflow_db_provider_good, workflow_db_provider_bad]), + ] + + _setup_list_providers_from_api_mocks( + monkeypatch, + session=session, + hardcoded_controller=hardcoded_controller, + plugin_controller=plugin_controller, + api_controller=api_controller, + workflow_controller=workflow_controller, + ) + providers = ToolManager.list_providers_from_api(user_id="user-1", tenant_id="tenant-1", typ="") + + names = {provider.name for provider in providers} + assert {"hardcoded", "plugin-provider", "api-provider", "workflow-provider", "mcp-provider"} <= names + + +def test_get_api_provider_controller_returns_controller_and_credentials(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch( + "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller + ) as mock_from_db: + built_controller, credentials = ToolManager.get_api_provider_controller("tenant-1", "api-1") + + assert built_controller is controller + assert credentials == provider.credentials + mock_from_db.assert_called_with(provider, ApiProviderAuthType.API_KEY_QUERY) + controller.load_bundled_tools.assert_called_once_with(provider.tools) + + +def test_user_get_api_provider_masks_credentials_and_adds_labels(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key_value": "secret"} + encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + with 
patch("core.tools.tool_manager.ToolLabelManager.get_tool_labels", return_value=["search"]): + user_payload = ToolManager.user_get_api_provider("api-provider", "tenant-1") + + assert user_payload["credentials"]["api_key_value"] == "***" + assert user_payload["labels"] == ["search"] + + +def test_get_api_provider_controller_not_found_raises(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): + ToolManager.get_api_provider_controller("tenant-1", "missing") + + +def test_get_mcp_provider_controller_returns_controller(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + controller = Mock() + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider.return_value = provider_entity + with patch("core.tools.tool_manager.MCPToolProviderController.from_db", return_value=controller): + built = ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + assert built is controller + + +def test_generate_mcp_tool_icon_url_returns_provider_icon(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider_entity.return_value = provider_entity + assert ToolManager.generate_mcp_tool_icon_url("tenant-1", "mcp-1") == provider_entity.provider_icon + + +def test_get_mcp_provider_controller_missing_raises(): + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service_cls.return_value.get_provider.side_effect = ValueError("missing") + with pytest.raises(ToolProviderNotFoundError, match="mcp provider mcp-1 not found"): + ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + + +def test_generate_tool_icon_urls_for_builtin_and_plugin(): + with patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "https://console.example.com"): + builtin_url = ToolManager.generate_builtin_tool_icon_url("time") + plugin_url = ToolManager.generate_plugin_tool_icon_url("tenant-1", "icon.svg") + + assert builtin_url.endswith("/tool-provider/builtin/time/icon") + assert "/plugin/icon" in plugin_url + + +def test_generate_tool_icon_urls_for_workflow_and_api(): + workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') + api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} + assert ToolManager.generate_api_tool_icon_url("tenant-1", 
"api-1") == {"background": "#333", "content": "A"} + + +def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + + +def test_get_tool_icon_for_builtin_provider_variants(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", return_value="plugin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "plugin-provider") == "plugin-icon" + + with patch.object(ToolManager, "get_builtin_provider", return_value=SimpleNamespace()): + with patch.object(ToolManager, "generate_builtin_tool_icon_url", return_value="builtin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "time") == "builtin-icon" + + +def test_get_tool_icon_for_api_workflow_and_mcp(): + with patch.object(ToolManager, "generate_api_tool_icon_url", return_value={"background": "#000"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.API, "api-1") == {"background": "#000"} + + with patch.object(ToolManager, "generate_workflow_tool_icon_url", return_value={"background": "#111"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.WORKFLOW, "wf-1") == {"background": "#111"} + + with patch.object(ToolManager, "generate_mcp_tool_icon_url", return_value={"background": "#222"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.MCP, "mcp-1") == {"background": "#222"} + + +def test_get_tool_icon_plugin_error_returns_default(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", side_effect=RuntimeError("fail")): + icon = ToolManager.get_tool_icon("tenant-1", ToolProviderType.PLUGIN, "plugin-provider") + assert icon["background"] == "#252525" + + +def test_get_tool_icon_invalid_provider_type_raises(): + with pytest.raises(ValueError, match="provider type"): + ToolManager.get_tool_icon("tenant-1", "invalid", "x") # type: ignore[arg-type] + + +def test_convert_tool_parameters_type_agent_and_workflow_branches(): + file_param = ToolParameter.get_simple_instance( + name="file", + llm_description="file", + typ=ToolParameter.ToolParameterType.FILE, + required=True, + ) + file_param.form = ToolParameter.ToolParameterForm.FORM + + with pytest.raises(ValueError, match="file type parameter file not supported in agent"): + ToolManager._convert_tool_parameters_type( + parameters=[file_param], + variable_pool=None, + tool_configurations={"file": "x"}, + typ="agent", + ) + + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + plain = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=None, + 
tool_configurations={"text": "hello"}, + typ="workflow", + ) + assert plain == {"text": "hello"} + + variable_pool = Mock() + variable_pool.get.return_value = SimpleNamespace(value="from-variable") + variable_pool.convert_template.return_value = SimpleNamespace(text="from-template") + + mixed = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "mixed", "value": "Hello {{name}}"}}, + typ="workflow", + ) + assert mixed == {"text": "from-template"} + + variable = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "variable", "value": ["sys", "query"]}}, + typ="workflow", + ) + assert variable == {"text": "from-variable"} + + +def test_convert_tool_parameters_type_constant_branch(): + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + variable_pool = Mock() + + constant = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "constant", "value": "fixed"}}, + typ="workflow", + ) + + assert constant == {"text": "fixed"} diff --git a/api/tests/unit_tests/core/tools/test_tool_provider_controller.py b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py new file mode 100644 index 0000000000..30b8494c92 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest + +from core.entities.provider_entities import ProviderConfig +from core.tools.__base.tool import Tool +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError + + +class _DummyTool(Tool): + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _DummyController(ToolProviderController): + def get_tool(self, tool_name: str) -> Tool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name=tool_name, + label=I18nObject(en_US=tool_name), + provider="provider", + ), + parameters=[], + ) + return _DummyTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant")) + + +def _provider_identity() -> ToolProviderIdentity: + return ToolProviderIdentity( + author="author", + name="provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ) + + +def test_tool_provider_controller_get_credentials_schema_returns_deep_copy(): + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="api_key", required=False)], + ) + controller = 
_DummyController(entity=entity) + + schema = controller.get_credentials_schema() + schema[0].name = "changed" + + assert controller.entity.credentials_schema[0].name == "api_key" + + +def test_tool_provider_controller_default_provider_type(): + entity = ToolProviderEntity(identity=_provider_identity(), credentials_schema=[]) + controller = _DummyController(entity=entity) + + assert controller.provider_type == ToolProviderType.BUILT_IN + + +def test_validate_credentials_format_covers_required_default_and_type_rules(): + select_options = [ProviderConfig.Option(value="opt-a", label=I18nObject(en_US="A"))] + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="required_text", required=True), + ProviderConfig(type=ProviderConfig.Type.SECRET_INPUT, name="secret", required=False), + ProviderConfig(type=ProviderConfig.Type.SELECT, name="choice", required=False, options=select_options), + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="with_default", required=False, default="x"), + ], + ) + controller = _DummyController(entity=entity) + + credentials = {"required_text": "value", "secret": None, "choice": "opt-a"} + controller.validate_credentials_format(credentials) + assert credentials["with_default"] == "x" + + with pytest.raises(ToolProviderCredentialValidationError, match="not found"): + controller.validate_credentials_format({"required_text": "value", "unknown": "v"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="is required"): + controller.validate_credentials_format({"secret": "s"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="should be string"): + controller.validate_credentials_format({"required_text": 123}) # type: ignore[arg-type] + + with pytest.raises(ToolProviderCredentialValidationError, match="should be one of"): + controller.validate_credentials_format({"required_text": "value", "choice": "opt-b"}) diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py new file mode 100644 index 0000000000..5ceaa08893 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.tool_parameter_cache import ToolParameterCache +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.utils.configuration import ToolParameterConfigurationManager + + +class _DummyTool(Tool): + runtime_overrides: list[ToolParameter] + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, runtime_overrides: list[ToolParameter]): + super().__init__(entity=entity, runtime=runtime) + self.runtime_overrides = runtime_overrides + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + def 
get_runtime_parameters( + self, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> list[ToolParameter]: + return self.runtime_overrides + + +def _param( + name: str, + *, + typ: ToolParameter.ToolParameterType, + form: ToolParameter.ToolParameterForm, + required: bool = False, +) -> ToolParameter: + return ToolParameter( + name=name, + label=I18nObject(en_US=name), + placeholder=I18nObject(en_US=""), + human_description=I18nObject(en_US=""), + type=typ, + form=form, + required=required, + default=None, + ) + + +def _build_manager() -> ToolParameterConfigurationManager: + base_params = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("plain", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + runtime_overrides = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("runtime_only", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + entity = ToolEntity( + identity=ToolIdentity(author="a", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=base_params, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + tool = _DummyTool(entity=entity, runtime=runtime, runtime_overrides=runtime_overrides) + return ToolParameterConfigurationManager( + tenant_id="tenant-1", + tool_runtime=tool, + provider_name="provider-a", + provider_type=ToolProviderType.BUILT_IN, + identity_id="ID.1", + ) + + +def test_merge_and_mask_parameters(): + manager = _build_manager() + + masked = manager.mask_tool_parameters({"secret": "abcdefghi", "plain": "x", "runtime_only": "y"}) + assert masked["secret"] == "ab*****hi" + assert masked["plain"] == "x" + assert masked["runtime_only"] == "y" + + +def test_encrypt_tool_parameters(): + manager = _build_manager() + + with patch("core.tools.utils.configuration.encrypter.encrypt_token", return_value="enc"): + encrypted = manager.encrypt_tool_parameters({"secret": "raw", "plain": "x"}) + + assert encrypted["secret"] == "enc" + assert encrypted["plain"] == "x" + + +def test_decrypt_tool_parameters_cache_hit_and_miss(): + manager = _build_manager() + + with ( + patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}), + patch.object(ToolParameterCache, "set") as mock_set, + ): + assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"} + mock_set.assert_not_called() + + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch.object(ToolParameterCache, "set") as mock_set, + patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) + assert decrypted["secret"] == "dec" + mock_set.assert_called_once() + + +def test_delete_tool_parameters_cache(): + manager = _build_manager() + + with patch.object(ToolParameterCache, "delete") as mock_delete: + manager.delete_tool_parameters_cache() + + mock_delete.assert_called_once() + + +def test_configuration_manager_decrypt_suppresses_errors(): + manager = _build_manager() + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) + # decryption failure is 
suppressed, original value is retained. + assert decrypted["secret"] == "enc" diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py index 94be0bb573..ce77473dbd 100644 --- a/api/tests/unit_tests/core/tools/utils/test_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -1,10 +1,13 @@ import copy -from unittest.mock import patch +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch import pytest from core.entities.provider_entities import BasicProviderConfig from core.helper.provider_encryption import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter # --------------------------- @@ -13,13 +16,13 @@ from core.helper.provider_encryption import ProviderConfigEncrypter class NoopCache: """Simple cache stub: always returns None, does nothing for set/delete.""" - def get(self): + def get(self) -> Any | None: return None - def set(self, config): + def set(self, config: Any) -> None: pass - def delete(self): + def delete(self) -> None: pass @@ -179,3 +182,35 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj): out = encrypter_obj.decrypt({"password": "ENC_ERR"}) assert out["password"] == "ENC_ERR" + + +def test_create_tool_provider_encrypter_builds_cache_and_encrypter(): + basic_config = BasicProviderConfig(name="key", type=BasicProviderConfig.Type.TEXT_INPUT) + credential_schema_item = SimpleNamespace(to_basic_provider_config=lambda: basic_config) + controller = SimpleNamespace( + provider_type=SimpleNamespace(value="builtin"), + entity=SimpleNamespace(identity=SimpleNamespace(name="provider-a")), + get_credentials_schema=lambda: [credential_schema_item], + ) + + cache_instance = Mock() + encrypter_instance = Mock() + + with patch( + "core.tools.utils.encryption.SingletonProviderCredentialsCache", return_value=cache_instance + ) as cache_cls: + with patch("core.tools.utils.encryption.ProviderConfigEncrypter", return_value=encrypter_instance) as enc_cls: + encrypter, cache = create_tool_provider_encrypter("tenant-1", controller) + + assert encrypter is encrypter_instance + assert cache is cache_instance + cache_cls.assert_called_once_with( + tenant_id="tenant-1", + provider_type="builtin", + provider_identity="provider-a", + ) + enc_cls.assert_called_once_with( + tenant_id="tenant-1", + config=[basic_config], + provider_config_cache=cache_instance, + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py new file mode 100644 index 0000000000..4ce73272bf --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import uuid +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +from yaml import YAMLError + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.models.document import Document as RagDocument +from core.tools.utils.dataset_retriever import dataset_multi_retriever_tool as multi_retriever_module +from core.tools.utils.dataset_retriever import dataset_retriever_tool as single_retriever_module +from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from 
core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool as SingleDatasetRetrieverTool +from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.tools.utils.uuid_utils import is_valid_uuid +from core.tools.utils.yaml_utils import _load_yaml_file, load_yaml_file_cached + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE) + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None, **_kwargs): + self._target = target + self._kwargs = kwargs or {} + + def start(self): + if self._target is not None: + self._target(**self._kwargs) + + def join(self): + return None + + +class _TestHitCallback(DatasetIndexToolCallbackHandler): + def __init__(self): + self.queries: list[tuple[str, str]] = [] + self.documents: list[RagDocument] | None = None + self.resources = None + + def on_query(self, query: str, dataset_id: str): + self.queries.append((query, dataset_id)) + + def on_tool_end(self, documents: list[RagDocument]): + self.documents = documents + + def return_retriever_resource_info(self, resource): + self.resources = list(resource) + + +def test_remove_leading_symbols_preserves_markdown_link_and_strips_punctuation(): + markdown = "[Example](https://example.com) content" + assert remove_leading_symbols(markdown) == markdown + + assert remove_leading_symbols("...Hello world") == "Hello world" + + +def test_is_valid_uuid_handles_valid_invalid_and_empty_values(): + assert is_valid_uuid(str(uuid.uuid4())) is True + assert is_valid_uuid("not-a-uuid") is False + assert is_valid_uuid("") is False + assert is_valid_uuid(None) is False + + +def test_load_yaml_file_valid(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + loaded = _load_yaml_file(file_path=str(valid_file)) + + assert loaded == {"a": 1, "b": "two"} + + +def test_load_yaml_file_missing(tmp_path): + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=str(tmp_path / "missing.yaml")) + + +def test_load_yaml_file_invalid(tmp_path): + invalid_file = tmp_path / "invalid.yaml" + invalid_file.write_text("a: [1, 2\n", encoding="utf-8") + + with pytest.raises(YAMLError): + _load_yaml_file(file_path=str(invalid_file)) + + +def test_load_yaml_file_cached_hits(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + load_yaml_file_cached.cache_clear() + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + assert load_yaml_file_cached.cache_info().hits == 1 + + +def test_single_dataset_retriever_from_dataset_builds_name_and_description(): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="Knowledge", description=None) + + tool = SingleDatasetRetrieverTool.from_dataset( + dataset=dataset, + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + inputs={}, + ) + + assert tool.name == "dataset_dataset_1" + assert tool.description == "useful for when you want to answer queries about the Knowledge" + + +def test_single_dataset_retriever_external_run_returns_content_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="external", + 
indexing_technique="high_quality", + retrieval_model={}, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ( + {"dataset-1": ["doc-a"]}, + {"logical_operator": "and"}, + ) + db_session = Mock() + db_session.scalar.return_value = dataset + external_documents = [ + {"content": "first", "metadata": {"document_id": "doc-a"}, "score": 0.9, "title": "Doc A"}, + {"content": "second", "metadata": {"document_id": "doc-b"}, "score": 0.8, "title": "Doc B"}, + ] + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={"x": 1}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object( + single_retriever_module.ExternalDatasetService, + "fetch_external_knowledge_retrieval", + return_value=external_documents, + ) as fetch_mock: + result = tool.run(query="hello") + + assert result == "first\nsecond" + assert callback.queries == [("hello", "dataset-1")] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].dataset_id == "dataset-1" + fetch_mock.assert_called_once() + + +def test_single_dataset_retriever_returns_empty_when_metadata_filter_finds_no_documents(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model=None, + ) + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ({"dataset-1": []}, {"logical_operator": "and"}) + db_session = Mock() + db_session.scalar.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + hit_callbacks=[_TestHitCallback()], + inputs={}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool.run(query="hello") + + assert result == "" + retrieve_mock.assert_not_called() + + +def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "score_threshold_enabled": True, + "score_threshold": 0.2, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "provider", "reranking_model_name": "model"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {"vector_weight": 0.6}}, + }, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = (None, None) + low_segment = SimpleNamespace( + id="seg-low", + dataset_id="dataset-1", + document_id="doc-low", + content="raw low", + answer="low answer", + hit_count=1, + word_count=10, + position=3, + index_node_hash="hash-low", + get_sign_content=lambda: "signed low", + ) 
+ high_segment = SimpleNamespace( + id="seg-high", + dataset_id="dataset-1", + document_id="doc-high", + content="raw high", + answer=None, + hit_count=9, + word_count=25, + position=1, + index_node_hash="hash-high", + get_sign_content=lambda: "signed high", + ) + records = [ + SimpleNamespace(segment=low_segment, score=0.2, summary="summary low"), + SimpleNamespace(segment=high_segment, score=0.9, summary=None), + ] + documents = [ + RagDocument(page_content="first", metadata={"doc_id": "node-low", "score": 0.2}), + RagDocument(page_content="second", metadata={"doc_id": "node-high", "score": 0.9}), + ] + lookup_doc_low = SimpleNamespace( + id="doc-low", name="Document Low", data_source_type="upload_file", doc_metadata={"lang": "en"} + ) + lookup_doc_high = SimpleNamespace( + id="doc-high", name="Document High", data_source_type="notion", doc_metadata={"lang": "fr"} + ) + db_session = Mock() + db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] + db_session.query.return_value.filter_by.return_value.first.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={}, + top_k=2, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve", return_value=documents): + with patch.object( + single_retriever_module.RetrievalService, + "format_retrieval_documents", + return_value=records, + ): + result = tool.run(query="hello") + + assert result == "signed high\nsummary low\nquestion:signed low answer:low answer" + assert callback.documents == documents + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].segment_id == "seg-high" + assert resource_info[0].hit_count == 9 + assert resource_info[1].summary == "summary low" + assert resource_info[1].content == "question:raw low \nanswer:low answer" + + +def test_multi_dataset_retriever_from_dataset_sets_tool_name(): + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=["dataset-1"], + tenant_id="tenant-1", + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + assert tool.name == "dataset_tenant_1" + + +def test_multi_dataset_retriever_retriever_returns_early_when_dataset_is_missing(): + callback = _TestHitCallback() + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = None + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert result == [] + assert all_documents == [] + assert callback.queries == [] + retrieve_mock.assert_not_called() + + +def test_multi_dataset_retriever_retriever_non_economy_uses_retrieval_model(): 
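+ """Dataset-level retrieval_model settings (top_k, score threshold) should override the tool's own defaults when calling RetrievalService.retrieve."""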
+ dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 6, + "score_threshold_enabled": True, + "score_threshold": 0.4, + "reranking_enable": False, + "reranking_mode": None, + "weights": {"balanced": True}, + }, + ) + callback = _TestHitCallback() + documents = [RagDocument(page_content="retrieved", metadata={"doc_id": "node-1", "score": 0.4})] + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = dataset + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + top_k=2, + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve", return_value=documents) as retrieve_mock: + tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert all_documents == documents + assert callback.queries == [("hello", "dataset-1")] + retrieve_mock.assert_called_once_with( + retrieval_method="semantic_search", + dataset_id="dataset-1", + query="hello", + top_k=6, + score_threshold=0.4, + reranking_model=None, + reranking_mode="reranking_model", + weights={"balanced": True}, + ) + + +def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): + callback = _TestHitCallback() + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1", "dataset-2"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + top_k=2, + score_threshold=0.1, + ) + first_doc = RagDocument(page_content="first", metadata={"doc_id": "node-2", "score": 0.4}) + second_doc = RagDocument(page_content="second", metadata={"doc_id": "node-1", "score": 0.9}) + + def fake_retriever(**kwargs): + if kwargs["dataset_id"] == "dataset-1": + kwargs["all_documents"].append(first_doc) + else: + kwargs["all_documents"].append(second_doc) + + segment_for_node_2 = SimpleNamespace( + id="seg-2", + dataset_id="dataset-1", + document_id="doc-2", + index_node_id="node-2", + content="raw two", + answer="answer two", + hit_count=2, + word_count=20, + position=2, + index_node_hash="hash-2", + get_sign_content=lambda: "signed two", + ) + segment_for_node_1 = SimpleNamespace( + id="seg-1", + dataset_id="dataset-2", + document_id="doc-1", + index_node_id="node-1", + content="raw one", + answer=None, + hit_count=7, + word_count=30, + position=1, + index_node_hash="hash-1", + get_sign_content=lambda: "signed one", + ) + db_session = Mock() + db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] + db_session.query.return_value.filter_by.return_value.first.side_effect = [ + SimpleNamespace(id="dataset-2", name="Dataset Two"), + SimpleNamespace(id="dataset-1", name="Dataset One"), + ] + db_session.scalar.side_effect = [ + SimpleNamespace(id="doc-1", name="Doc One", data_source_type="upload_file", doc_metadata={"p": 1}), + SimpleNamespace(id="doc-2", name="Doc Two", data_source_type="notion", doc_metadata={"p": 2}), + ] + model_manager = Mock() + model_manager.get_model_instance.return_value = Mock() + rerank_runner = Mock() + rerank_runner.run.return_value = [second_doc, first_doc] + 
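# current_app and threading.Thread are swapped for inline stubs so each _retriever call runs synchronously on the test thread. +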
fake_current_app = SimpleNamespace(_get_current_object=lambda: _FakeFlaskApp()) + + with patch.object(tool, "_retriever", side_effect=fake_retriever) as retriever_mock: + with patch.object(multi_retriever_module, "current_app", fake_current_app): + with patch.object(multi_retriever_module.threading, "Thread", _ImmediateThread): + with patch.object(multi_retriever_module, "ModelManager", return_value=model_manager): + with patch.object(multi_retriever_module, "RerankModelRunner", return_value=rerank_runner): + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + result = tool.run(query="hello") + + assert result == "signed one\nquestion:signed two answer:answer two" + assert retriever_mock.call_count == 2 + assert callback.documents == [second_doc, first_doc] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].score == 0.9 + assert resource_info[0].content == "raw one" + assert resource_info[1].score == 0.4 + assert resource_info[1].content == "question:raw two \nanswer:answer two" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py new file mode 100644 index 0000000000..2acae889b2 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -0,0 +1,158 @@ +"""Unit tests for ModelInvocationUtils. + +Covers success and error branches for ModelInvocationUtils, including +InvokeModelError and invoke error mappings for InvokeAuthorizationError, +InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, and +InvokeServerUnavailableError. Assumes mocked model instances and managers. 
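+Database writes are captured through a stand-in ToolModelInvoke class and a mocked db session, so no persistence layer is touched.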
+""" + +from __future__ import annotations + +from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = ( + SimpleNamespace(model_properties=schema or {}) if schema is not None else None + ) + return SimpleNamespace( + provider="provider", + model="model-a", + model_name="model-a", + credentials={"api_key": "x"}, + model_type_instance=model_type_instance, + get_llm_num_tokens=lambda prompt_messages: 5, + invoke_llm=Mock(), + ) + + +@pytest.mark.parametrize( + ("model_instance", "expected", "error_match"), + [ + (None, None, "Model not found"), + (_mock_model_instance(schema=None), None, "No model schema found"), + (_mock_model_instance(schema={}), 2048, None), + (_mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 8192}), 8192, None), + ], + ids=[ + "missing-model", + "missing-schema", + "default-context-size", + "schema-context-size", + ], +) +def test_get_max_llm_context_tokens_branches(model_instance, expected, error_match): + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + if error_match: + with pytest.raises(InvokeModelError, match=error_match): + ModelInvocationUtils.get_max_llm_context_tokens("tenant") + else: + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + + +def test_calculate_tokens_handles_missing_model(): + manager = Mock() + manager.get_default_model_instance.return_value = None + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with pytest.raises(InvokeModelError, match="Model not found"): + ModelInvocationUtils.calculate_tokens("tenant", []) + + +def test_invoke_success_and_error_mappings(): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.return_value = SimpleNamespace( + message=SimpleNamespace(content="ok"), + usage=SimpleNamespace( + completion_tokens=7, + completion_unit_price=Decimal("0.1"), + completion_price_unit=Decimal(1), + latency=0.3, + total_price=Decimal("0.7"), + currency="USD", + ), + ) + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + response = ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) + + assert response.message.content == "ok" + assert db_mock.session.add.call_count == 1 + assert db_mock.session.commit.call_count == 2 + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + 
(InvokeRateLimitError("rate"), "Invoke rate limit error"), + (InvokeBadRequestError("bad"), "Invoke bad request error"), + (InvokeConnectionError("conn"), "Invoke connection error"), + (InvokeAuthorizationError("auth"), "Invoke authorization error"), + (InvokeServerUnavailableError("down"), "Invoke server unavailable error"), + (RuntimeError("oops"), "Invoke error"), + ], + ids=[ + "rate-limit", + "bad-request", + "connection", + "authorization", + "server-unavailable", + "generic-error", + ], +) +def test_invoke_error_mappings(exc, expected): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.side_effect = exc + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + with pytest.raises(InvokeModelError, match=expected): + ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index f39158aa59..40f91b12a0 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,6 +1,12 @@ +from json.decoder import JSONDecodeError +from unittest.mock import Mock, patch + import pytest from flask import Flask +from yaml import YAMLError +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError from core.tools.utils.parser import ApiBasedToolSchemaParser @@ -189,3 +195,225 @@ def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): available_param = params_by_name["available"] assert available_param.type == "boolean" assert available_param.default is True + + +def test_sanitize_default_value_and_type_detection(): + assert ApiBasedToolSchemaParser._sanitize_default_value([]) is None + assert ApiBasedToolSchemaParser._sanitize_default_value({}) is None + assert ApiBasedToolSchemaParser._sanitize_default_value("ok") == "ok" + + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"}) == ToolParameter.ToolParameterType.FILE + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "integer"}) == ToolParameter.ToolParameterType.NUMBER + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "boolean"}}) + == ToolParameter.ToolParameterType.BOOLEAN + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}}) + == ToolParameter.ToolParameterType.FILES + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}}) + == ToolParameter.ToolParameterType.ARRAY + ) + assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "object"}) is None + + +def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "API", "version": "1.0.0", "description": "API description"}, + "servers": [ + {"url": 
"https://dev.example.com", "env": "dev"}, + {"url": "https://prod.example.com", "env": "prod"}, + ], + "paths": { + "/items": { + "post": { + "description": "Create item", + "parameters": [ + {"$ref": "#/components/parameters/token"}, + {"name": "token", "schema": {"type": "string"}}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ItemRequest"}}} + }, + } + } + }, + "components": { + "parameters": { + "token": {"name": "token", "required": True, "schema": {"type": "string"}}, + }, + "schemas": { + "ItemRequest": { + "type": "object", + "required": ["age"], + "properties": {"age": {"type": "integer", "description": "Age", "default": 18}}, + } + }, + }, + } + + extra_info: dict = {} + warning: dict = {} + with app.test_request_context(headers={"X-Request-Env": "prod"}): + bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + assert len(bundles) == 1 + assert bundles[0].server_url == "https://prod.example.com/items" + assert warning["duplicated_parameter"].startswith("Parameter token") + assert extra_info["description"] == "API description" + + +def test_parse_openapi_to_tool_bundle_no_server_raises(app): + openapi = {"info": {"title": "x"}, "servers": [], "paths": {}} + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="No server found"): + ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + +def test_parse_openapi_yaml_to_tool_bundle_invalid_yaml(app): + with app.test_request_context(): + with pytest.raises(ToolApiSchemaError, match="Invalid openapi yaml"): + ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle("null") + + +def test_parse_swagger_to_openapi_branches(): + with pytest.raises(ToolApiSchemaError, match="No server found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"info": {}, "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No paths found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"servers": [{"url": "https://x"}], "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No operationId found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"summary": "x", "responses": {}}}}, + } + ) + + warning: dict = {"seed": True} + converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"operationId": "getA", "responses": {}}}}, + "definitions": {"A": {"type": "object"}}, + }, + warning=warning, + ) + assert converted["openapi"] == "3.0.0" + assert converted["components"]["schemas"]["A"]["type"] == "object" + assert warning["missing_summary"].startswith("No summary or description found") + + +def test_parse_openai_plugin_json_branches(app): + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="Invalid openai plugin json"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle("{bad") + + with pytest.raises(ToolNotSupportedError, match="Only openapi is supported"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "graphql"}}' + ) + + +def test_parse_openai_plugin_json_http_branches(app): + with app.test_request_context(): + response = type("Resp", (), {"status_code": 500, "text": "", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=response): + with pytest.raises(ToolProviderNotFoundError, 
match="cannot get openapi yaml"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + response.close.assert_called_once() + + success_response = type("Resp", (), {"status_code": 200, "text": "openapi: 3.0.0", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=success_response): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle", + return_value=["bundle"], + ) as mock_parse: + bundles = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + assert bundles == ["bundle"] + mock_parse.assert_called_once() + success_response.close.assert_called_once() + + +def test_auto_parse_json_yaml_failure(): + with patch("core.tools.utils.parser.json_loads", side_effect=JSONDecodeError("bad", "x", 0)): + with patch("core.tools.utils.parser.safe_load", side_effect=YAMLError("bad yaml")): + with pytest.raises(ToolApiSchemaError, match="Invalid api schema, schema is neither json nor yaml"): + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(":::") + + +def test_auto_parse_openapi_success(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + return_value=["openapi-bundle"], + ): + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["openapi-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAPI + + +def test_auto_parse_openapi_then_swagger(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + loaded_content = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + converted_swagger = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=[ToolApiSchemaError("openapi error"), ["swagger-bundle"]], + ) as mock_parse_openapi: + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + return_value=converted_swagger, + ) as mock_parse_swagger: + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["swagger-bundle"] + assert schema_type == ApiProviderSchemaType.SWAGGER + mock_parse_swagger.assert_called_once_with(loaded_content, extra_info={}, warning={}) + assert mock_parse_openapi.call_count == 2 + mock_parse_openapi.assert_any_call(loaded_content, extra_info={}, warning={}) + mock_parse_openapi.assert_any_call(converted_swagger, extra_info={}, warning={}) + + +def test_auto_parse_openapi_swagger_then_plugin(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=ToolApiSchemaError("openapi error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + side_effect=ToolApiSchemaError("swagger error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle", + return_value=["plugin-bundle"], + ): + bundles, schema_type = 
ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["plugin-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAI_PLUGIN diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py new file mode 100644 index 0000000000..5691f33e65 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import pytest + +from core.tools.utils import system_oauth_encryption as oauth_encryption +from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter + + +def test_system_oauth_encrypter_roundtrip(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} + + encrypted = encrypter.encrypt_oauth_params(payload) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert encrypted + assert dict(decrypted) == payload + + +def test_system_oauth_encrypter_decrypt_validates_input(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(ValueError, match="must be a string"): + encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="cannot be empty"): + encrypter.decrypt_oauth_params("") + + +def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(OAuthEncryptionError, match="Decryption failed"): + encrypter.decrypt_oauth_params("not-base64") + + +def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") + + first = oauth_encryption.get_system_oauth_encrypter() + second = oauth_encryption.get_system_oauth_encrypter() + assert first is second + + encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) + assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + + +def test_create_system_oauth_encrypter_factory(): + encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemOAuthEncrypter) diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index c46e31d90f..dd79b79718 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,7 +1,9 @@ import pytest +from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): @@ -31,3 +33,91 @@ def test_ensure_no_human_input_nodes_raises_for_human_input(): WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + +def test_get_workflow_graph_variables_and_outputs(): + graph = { + "nodes": [ + { + "id": 
"start", + "data": { + "type": "start", + "variables": [ + { + "variable": "query", + "label": "Query", + "type": "text-input", + "required": True, + } + ], + }, + }, + { + "id": "end-1", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "string", "value_selector": ["n1", "answer"]}, + {"variable": "score", "value_type": "number", "value_selector": ["n1", "score"]}, + ], + }, + }, + { + "id": "end-2", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "object", "value_selector": ["n2", "answer"]}, + ], + }, + }, + ] + } + + variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) + assert len(variables) == 1 + assert variables[0].variable == "query" + assert variables[0].type == VariableEntityType.TEXT_INPUT + + outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph) + assert [output.variable for output in outputs] == ["answer", "score"] + assert outputs[0].value_type == "object" + assert outputs[1].value_type == "number" + + no_start = WorkflowToolConfigurationUtils.get_workflow_graph_variables({"nodes": []}) + assert no_start == [] + + +def test_check_is_synced_validation(): + variables = [ + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + required=True, + ) + ] + configs = [ + WorkflowToolParameterConfiguration( + name="query", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ] + + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=configs) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=[]) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced( + variables=variables, + tool_configurations=[ + WorkflowToolParameterConfiguration( + name="other", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ], + ) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py new file mode 100644 index 0000000000..dd140cbb27 --- /dev/null +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +def _controller() -> WorkflowToolProviderController: + entity = ToolProviderEntity( + identity=ToolProviderIdentity( + author="author", + name="wf-provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="WF"), + ), + credentials_schema=[], + ) + return WorkflowToolProviderController(entity=entity, provider_id="provider-1") + + +def _mock_session_with_begin() -> Mock: + session = Mock() + begin_cm = Mock() + begin_cm.__enter__ = Mock(return_value=None) + begin_cm.__exit__ = Mock(return_value=False) + session.begin.return_value = begin_cm + return session + + +def test_get_db_provider_tool_builds_entity(): + controller = _controller() + 
session = Mock() + workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) + session.query.return_value.where.return_value.first.return_value = workflow + app = SimpleNamespace(id="app-1") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + user_id="user-1", + parameter_configurations=[ + SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM), + SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM), + ], + ) + user = SimpleNamespace(name="Alice") + variables = [ + VariableEntity( + variable="country", + label="Country", + description="Country", + type=VariableEntityType.SELECT, + required=True, + options=["US", "IN"], + ) + ] + outputs = [ + SimpleNamespace(variable="json", value_type="string"), + SimpleNamespace(variable="answer", value_type="string"), + ] + + with ( + patch( + "core.tools.workflow_as_tool.provider.WorkflowAppConfigManager.convert_features", + return_value=SimpleNamespace(file_upload=True), + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_variables", + return_value=variables, + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_output", + return_value=outputs, + ), + ): + tool = controller._get_db_provider_tool(db_provider, app, session=session, user=user) + + assert tool.entity.identity.name == "workflow_tool" + # "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out. + assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}} + assert "json" not in tool.entity.output_schema["properties"] + assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT + assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES + assert controller.provider_type == ToolProviderType.WORKFLOW + + +def test_get_tool_returns_hit_or_none(): + controller = _controller() + tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool"))) + controller.tools = [tool] + + assert controller.get_tool("workflow_tool") is tool + assert controller.get_tool("missing") is None + + +def test_get_tools_returns_cached(): + controller = _controller() + cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))] + controller.tools = cached_tools # type: ignore[assignment] + + assert controller.get_tools("tenant-1") == cached_tools + + +def test_from_db_builds_controller(): + controller = _controller() + + app = SimpleNamespace(id="app-1") + user = SimpleNamespace(name="Alice") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.side_effect = [app, user] + fake_cm = MagicMock() + fake_cm.__enter__.return_value = session + fake_cm.__exit__.return_value = False + fake_session_factory = Mock() + fake_session_factory.create_session.return_value = fake_cm + + with patch("core.tools.workflow_as_tool.provider.session_factory", fake_session_factory): + with 
patch.object( + WorkflowToolProviderController, + "_get_db_provider_tool", + return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))), + ): + built = WorkflowToolProviderController.from_db(db_provider) + assert isinstance(built, WorkflowToolProviderController) + assert built.tools + + +def test_get_tools_returns_empty_when_provider_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = None + session_cls.return_value.__enter__.return_value = session + + assert controller.get_tools("tenant-1") == [] + + +def test_get_tools_raises_when_app_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.return_value = None + session_cls.return_value.__enter__.return_value = session + with pytest.raises(ValueError, match="app not found"): + controller.get_tools("tenant-1") diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 36fdb0218c..cc00f79698 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -1,20 +1,85 @@ +"""Unit tests for workflow-as-tool behavior. + +StubSession/StubScalars emulate SQLAlchemy session/scalars with minimal methods +(`scalar`, `scalars`, `expunge`, `commit`, `refresh`, context manager) to keep +database access mocked and predictable in tests. +""" + +import json from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.file import FILE_MODEL_IDENTITY -def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): - """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when - `WorkflowAppGenerator.generate` returns a result with `error` key inside - the `data` element. 
- """ +class StubScalars: + """Minimal stub for SQLAlchemy scalar results.""" + + _value: Any + + def __init__(self, value: Any) -> None: + self._value = value + + def first(self) -> Any: + return self._value + + +class StubSession: + """Minimal stub for session_factory-created sessions.""" + + scalar_results: list[Any] + scalars_results: list[Any] + expunge_calls: list[object] + + def __init__(self, *, scalar_results: list[Any] | None = None, scalars_results: list[Any] | None = None) -> None: + self.scalar_results = list(scalar_results or []) + self.scalars_results = list(scalars_results or []) + self.expunge_calls: list[object] = [] + + def scalar(self, _stmt: Any) -> Any: + return self.scalar_results.pop(0) + + def scalars(self, _stmt: Any) -> StubScalars: + return StubScalars(self.scalars_results.pop(0)) + + def expunge(self, value: Any) -> None: + self.expunge_calls.append(value) + + def begin(self) -> "StubSession": + return self + + def commit(self) -> None: + pass + + def refresh(self, _value: Any) -> None: + pass + + def close(self) -> None: + pass + + def __enter__(self) -> "StubSession": + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +def _build_tool() -> WorkflowTool: entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), parameters=[], @@ -22,9 +87,9 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel has_runtime_parameters=False, ) runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", + return WorkflowTool( + workflow_app_id="app-1", + workflow_as_tool_id="wf-tool-1", version="1", workflow_entities={}, workflow_call_depth=1, @@ -32,13 +97,19 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel runtime=runtime, ) + +def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): + """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when + `WorkflowAppGenerator.generate` returns a result with `error` key inside + the `data` element. + """ + tool = _build_tool() + # needs to patch those methods to avoid database access. 
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -56,28 +127,12 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure pause_state_config is passed as None.""" + tool = _build_tool() monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) - from unittest.mock import MagicMock, Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -94,22 +149,7 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # Mock workflow outputs mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}} @@ -119,8 +159,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -134,10 +172,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Execute tool invocation messages = list(tool.invoke("test_user", {})) - # Verify generated messages - # Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages - assert len(messages) == 5 - # Verify variable messages variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE] assert len(variable_messages) == 3 @@ -151,7 +185,7 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Verify text message text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT] assert len(text_messages) == 1 - assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text + assert json.loads(text_messages[0].message.text) == mock_outputs # Verify JSON message json_messages = [msg for msg in messages if msg.type == 
ToolInvokeMessage.MessageType.JSON] @@ -161,30 +195,13 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should handle empty outputs correctly""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # needs to patch those methods to avoid database access. monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -217,61 +234,32 @@ def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPat assert json_messages[0].message.json_object == {} -def test_create_variable_message(): - """Test the functionality of creating variable messages""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) - - # Test different types of variable values - test_cases = [ +@pytest.mark.parametrize( + ("var_name", "var_value"), + [ ("string_var", "test string"), ("int_var", 42), ("float_var", 3.14), ("bool_var", True), ("list_var", [1, 2, 3]), ("dict_var", {"key": "value"}), - ] + ], +) +def test_create_variable_message(var_name, var_value): + """Create variable messages for multiple value types.""" + tool = _build_tool() - for var_name, var_value in test_cases: - message = tool.create_variable_message(var_name, var_value) + message = tool.create_variable_message(var_name, var_value) - assert message.type == ToolInvokeMessage.MessageType.VARIABLE - assert message.message.variable_name == var_name - assert message.message.variable_value == var_value - assert message.message.stream is False + assert message.type == ToolInvokeMessage.MessageType.VARIABLE + assert message.message.variable_name == var_name + assert message.message.variable_value == var_value + assert message.message.stream is False def test_create_file_message_should_include_file_marker(): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure file message includes marker and meta payload.""" + tool = _build_tool() file_obj = object() message = tool.create_file_message(file_obj) # type: 
ignore[arg-type] @@ -284,103 +272,247 @@ def test_create_file_message_should_include_file_marker(): def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch): """Ensure worker context can resolve EndUser when Account is missing.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - # SQLAlchemy Session APIs used by code under test - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - # support `with session_factory.create_session() as session:` - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") # Monkeypatch session factory to return our stub session + stub_session = StubSession(scalar_results=[tenant, None, end_user]) monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([tenant, None, end_user]), + lambda: stub_session, ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "tenant_id" resolved_user = tool._resolve_user_from_database(user_id=end_user.id) assert resolved_user is end_user + assert stub_session.expunge_calls == [end_user] def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch): """Return None if tenant cannot be found in worker context.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - # Monkeypatch session factory to return our stub session with no tenant monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([None]), + lambda: StubSession(scalar_results=[None]), ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "missing_tenant" resolved_user = tool._resolve_user_from_database(user_id="any") assert resolved_user is None + + +def test_workflow_tool_provider_type_and_fork_runtime(): + """Verify provider type and forked runtime behavior.""" + tool = _build_tool() + assert tool.tool_provider_type() == ToolProviderType.WORKFLOW + assert tool.latest_usage.total_tokens == 0 + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2", 
invoke_from=InvokeFrom.DEBUGGER)) + assert isinstance(forked, WorkflowTool) + assert forked.workflow_app_id == tool.workflow_app_id + assert forked.runtime.tenant_id == "tenant-2" + + +def test_derive_usage_from_top_level_usage_key(): + """Derive usage from top-level usage dict.""" + usage = WorkflowTool._derive_usage_from_result({"usage": {"total_tokens": 12, "total_price": "0.2"}}) + assert usage.total_tokens == 12 + + +def test_derive_usage_from_metadata_usage(): + """Derive usage from metadata usage dict.""" + metadata_usage = WorkflowTool._derive_usage_from_result({"metadata": {"usage": {"total_tokens": 7}}}) + assert metadata_usage.total_tokens == 7 + + +def test_derive_usage_from_totals(): + """Derive usage from top-level totals fields.""" + totals_usage = WorkflowTool._derive_usage_from_result( + {"total_tokens": "9", "total_price": "1.3", "currency": "USD"} + ) + assert totals_usage.total_tokens == 9 + assert str(totals_usage.total_price) == "1.3" + + +def test_derive_usage_from_empty(): + """Default usage values when result is empty.""" + empty_usage = WorkflowTool._derive_usage_from_result({}) + assert empty_usage.total_tokens == 0 + + +def test_extract_usage_from_nested(): + """Extract nested usage dict from result payloads.""" + nested = WorkflowTool._extract_usage_dict({"nested": [{"data": {"usage": {"total_tokens": 3}}}]}) + assert nested == {"total_tokens": 3} + + +def test_invoke_raises_when_user_not_found(monkeypatch: pytest.MonkeyPatch): + """Raise ToolInvokeError when user resolution fails.""" + tool = _build_tool() + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: None) + + with pytest.raises(ToolInvokeError, match="User not found"): + list(tool.invoke("missing", {})) + + +def test_resolve_user_from_database_returns_account(monkeypatch: pytest.MonkeyPatch): + """Resolve Account and set tenant in worker context.""" + tenant = SimpleNamespace(id="tenant_id") + account = SimpleNamespace(id="account_id", current_tenant=None) + session = StubSession(scalar_results=[tenant, account]) + + monkeypatch.setattr("core.tools.workflow_as_tool.tool.session_factory.create_session", lambda: session) + tool = _build_tool() + tool.runtime.tenant_id = "tenant_id" + + resolved = tool._resolve_user_from_database(user_id="account_id") + assert resolved is account + assert account.current_tenant is tenant + assert session.expunge_calls == [account] + + +def test_get_workflow_and_get_app_db_branches(monkeypatch: pytest.MonkeyPatch): + """Cover workflow/app retrieval branches and error cases.""" + tool = _build_tool() + latest_workflow = SimpleNamespace(id="wf-latest") + specific_workflow = SimpleNamespace(id="wf-v1") + app = SimpleNamespace(id="app-1") + sessions = iter( + [ + StubSession(scalar_results=[], scalars_results=[latest_workflow]), + StubSession(scalar_results=[specific_workflow], scalars_results=[]), + StubSession(scalar_results=[app], scalars_results=[]), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: next(sessions), + ) + + assert tool._get_workflow("app-1", "") is latest_workflow + assert tool._get_workflow("app-1", "1") is specific_workflow + assert tool._get_app("app-1") is app + + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession(scalar_results=[None, None], scalars_results=[None]), + ) + 
with pytest.raises(ValueError, match="workflow not found"): + tool._get_workflow("app-1", "1") + with pytest.raises(ValueError, match="app not found"): + tool._get_app("app-1") + + +def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: + """Build a WorkflowTool and stub merged runtime parameters for files/query.""" + tool = _build_tool() + files_param = ToolParameter.get_simple_instance( + name="files", + llm_description="files", + typ=ToolParameter.ToolParameterType.SYSTEM_FILES, + required=False, + ) + files_param.form = ToolParameter.ToolParameterForm.FORM + text_param = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + + monkeypatch.setattr(tool, "get_merged_runtime_parameters", lambda: [files_param, text_param]) + return tool + + +def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): + """Transform args into parameters and files payloads.""" + tool = _setup_transform_args_tool(monkeypatch) + + params, files = tool._transform_args( + { + "query": "hello", + "files": [ + { + "tenant_id": "tenant-1", + "type": "image", + "transfer_method": "tool_file", + "related_id": "tool-1", + "extension": ".png", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "local_file", + "related_id": "upload-1", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "remote_url", + "remote_url": "https://example.com/a.pdf", + }, + ], + } + ) + assert params == {"query": "hello"} + assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) + assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) + assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + + +def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): + """Ignore invalid file entries while keeping params.""" + tool = _setup_transform_args_tool(monkeypatch) + invalid_params, invalid_files = tool._transform_args({"query": "hello", "files": [{"invalid": True}]}) + assert invalid_params == {"query": "hello"} + assert invalid_files == [] + + +def test_extract_files(): + """Extract file outputs into result and file list.""" + tool = _build_tool() + built_files = [ + SimpleNamespace(id="file-1"), + SimpleNamespace(id="file-2"), + ] + with patch("core.tools.workflow_as_tool.tool.build_from_mapping", side_effect=built_files): + outputs = { + "attachments": [ + { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "tool_file", + "related_id": "r1", + } + ], + "single_file": { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "local_file", + "related_id": "r2", + }, + "text": "ok", + } + result, extracted_files = tool._extract_files(outputs) + + assert result["text"] == "ok" + assert len(extracted_files) == 2 + + +def test_update_file_mapping(): + """Map tool/local file transfer methods into output shape.""" + tool = _build_tool() + tool_file = tool._update_file_mapping({"transfer_method": "tool_file", "related_id": "tool-1"}) + assert tool_file["tool_file_id"] == "tool-1" + local_file = tool._update_file_mapping({"transfer_method": "local_file", "related_id": "upload-1"}) + assert local_file["upload_file_id"] == "upload-1" diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index d47d4d6130..91259c9a45 100644 --- 
a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,11 +1,14 @@ import dataclasses +import orjson +import pytest from pydantic import BaseModel from core.helper import encrypter from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime import VariablePool from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, @@ -23,6 +26,11 @@ from dify_graph.variables.segments import ( get_segment_discriminator, ) from dify_graph.variables.types import SegmentType +from dify_graph.variables.utils import ( + dumps_with_segments, + segment_orjson_default, + to_selector, +) from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, @@ -379,3 +387,125 @@ class TestSegmentDumpAndLoad: assert get_segment_discriminator("not_a_dict") is None assert get_segment_discriminator(42) is None assert get_segment_discriminator(object) is None + + +class TestSegmentAdditionalProperties: + def test_base_segment_text_log_markdown_size_and_to_object(self): + """Ensure StringSegment exposes text, log, markdown, size and to_object.""" + segment = StringSegment(value="hello") + + assert segment.text == "hello" + assert segment.log == "hello" + assert segment.markdown == "hello" + assert segment.size > 0 + assert segment.to_object() == "hello" + + def test_none_segment_empty_outputs(self): + """Ensure NoneSegment renders empty text, log and markdown.""" + segment = NoneSegment() + + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + + def test_object_segment_json_outputs(self): + """Ensure ObjectSegment renders JSON output for text, log and markdown.""" + segment = ObjectSegment(value={"key": "值", "n": 1}) + + assert segment.text == '{"key": "值", "n": 1}' + assert segment.log == '{\n "key": "值",\n "n": 1\n}' + assert segment.markdown == '{\n "key": "值",\n "n": 1\n}' + + def test_array_segment_text_and_markdown(self): + """Ensure ArrayAnySegment handles empty/non-empty text and markdown rendering.""" + empty_segment = ArrayAnySegment(value=[]) + non_empty_segment = ArrayAnySegment(value=[1, "two"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == "[1, 'two']" + assert non_empty_segment.markdown == "- 1\n- two" + + def test_file_segment_properties(self): + """Ensure FileSegment markdown, text and log fields match expected behavior.""" + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="doc.txt") + segment = FileSegment(value=file) + + assert segment.markdown == "[doc.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + def test_array_string_segment_text_branches(self): + """Ensure ArrayStringSegment text handling for empty and non-empty values.""" + empty_segment = ArrayStringSegment(value=[]) + non_empty_segment = ArrayStringSegment(value=["hello", "世界"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == '["hello", "世界"]' + + def test_array_file_segment_markdown_and_empty_text_log(self): + """Ensure ArrayFileSegment markdown renders links and text/log stay empty.""" + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + segment = ArrayFileSegment(value=[file1, file2]) + + assert segment.markdown == 
"[a.txt](https://example.com/file.txt)\n[b.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + +class TestSegmentGroupAdditional: + def test_segment_group_markdown_and_to_object(self): + group = SegmentGroup(value=[StringSegment(value="A"), NoneSegment(), StringSegment(value="B")]) + + assert group.markdown == "AB" + assert group.to_object() == ["A", None, "B"] + + +class TestSegmentUtils: + def test_to_selector_without_paths(self): + assert to_selector("node-1", "output") == ["node-1", "output"] + + def test_to_selector_with_paths(self): + assert to_selector("node-1", "output", ("a", "b")) == ["node-1", "output", "a", "b"] + + def test_array_file_segment_serialization(self): + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + + result = segment_orjson_default(ArrayFileSegment(value=[file1, file2])) + + assert len(result) == 2 + assert result[0]["filename"] == "a.txt" + assert result[1]["filename"] == "b.txt" + + def test_file_segment_serialization(self): + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="single.txt") + + result = segment_orjson_default(FileSegment(value=file)) + + assert result["filename"] == "single.txt" + assert result["remote_url"] == "https://example.com/file.txt" + + def test_segment_group_and_segment_serialization(self): + group = SegmentGroup(value=[StringSegment(value="a"), StringSegment(value="b")]) + + assert segment_orjson_default(group) == ["a", "b"] + assert segment_orjson_default(StringSegment(value="value")) == "value" + + def test_segment_orjson_default_unsupported_type(self): + with pytest.raises(TypeError, match="not JSON serializable"): + segment_orjson_default(object()) + + def test_dumps_with_segments(self): + data = { + "segment": StringSegment(value="hello"), + "group": SegmentGroup(value=[StringSegment(value="x"), StringSegment(value="y")]), + 1: "numeric-key", + } + + dumped = dumps_with_segments(data) + loaded = orjson.loads(dumped) + + assert loaded["segment"] == "hello" + assert loaded["group"] == ["x", "y"] + assert loaded["1"] == "numeric-key" diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index c3371d92e3..bb234d9bbd 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,5 +1,7 @@ import pytest +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import ArrayValidation, SegmentType @@ -69,22 +71,36 @@ class TestSegmentTypeIsValidArrayValidation: """ def test_array_validation_all_success(self): + # Arrange value = ["hello", "world", "foo"] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert is_valid def test_array_validation_all_fail(self): + # Arrange value = ["hello", 123, "world"] - # Should return False, since 123 is not a string - assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert not is_valid def test_array_validation_first(self): + # Arrange value = ["hello", 123, 
None] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Assert + assert is_valid def test_array_validation_none(self): + # Arrange value = [1, 2, 3] - # validation is None, skip - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Assert + assert is_valid class TestSegmentTypeGetZeroValue: @@ -163,3 +179,62 @@ class TestSegmentTypeGetZeroValue: for seg_type in unsupported_types: with pytest.raises(ValueError, match="unsupported variable type"): SegmentType.get_zero_value(seg_type) + + +class TestSegmentTypeInferSegmentType: + @pytest.mark.parametrize( + ("value", "expected"), + [ + ([], SegmentType.ARRAY_NUMBER), + ([1, 2, 3], SegmentType.ARRAY_NUMBER), + ([1, 2.5], SegmentType.ARRAY_NUMBER), + (["a", "b"], SegmentType.ARRAY_STRING), + ([{"k": "v"}], SegmentType.ARRAY_OBJECT), + ([None], SegmentType.ARRAY_ANY), + ([True, False], SegmentType.ARRAY_BOOLEAN), + ([[1], [2]], SegmentType.ARRAY_ANY), + ([1, "a"], SegmentType.ARRAY_ANY), + (None, SegmentType.NONE), + (True, SegmentType.BOOLEAN), + (1, SegmentType.INTEGER), + (1.2, SegmentType.FLOAT), + ("abc", SegmentType.STRING), + ({"k": "v"}, SegmentType.OBJECT), + ], + ) + def test_infer_segment_type_supported_values(self, value, expected): + assert SegmentType.infer_segment_type(value) == expected + + +class TestSegmentTypeAdditionalMethods: + def test_cast_value_for_bool_number_and_array_number(self): + assert SegmentType.cast_value(True, SegmentType.INTEGER) == 1 + assert SegmentType.cast_value(False, SegmentType.NUMBER) == 0 + assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0] + + mixed = [True, 1] + assert SegmentType.cast_value(mixed, SegmentType.ARRAY_NUMBER) is mixed + assert SegmentType.cast_value("x", SegmentType.STRING) == "x" + + def test_exposed_type_and_element_type(self): + assert SegmentType.INTEGER.exposed_type() == SegmentType.NUMBER + assert SegmentType.FLOAT.exposed_type() == SegmentType.NUMBER + assert SegmentType.STRING.exposed_type() == SegmentType.STRING + + assert SegmentType.ARRAY_STRING.element_type() == SegmentType.STRING + assert SegmentType.ARRAY_ANY.element_type() is None + + with pytest.raises(ValueError, match="element_type is only supported by array type"): + SegmentType.STRING.element_type() + + def test_group_validation_for_segment_group_and_list(self): + valid_group = SegmentGroup(value=[StringSegment(value="a")]) + assert SegmentType.GROUP.is_valid(valid_group) is True + assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True + assert SegmentType.GROUP.is_valid(["not-segment"]) is False + + def test_unreachable_assertion_branch(self, monkeypatch): + monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False) + + with pytest.raises(AssertionError, match="unreachable"): + SegmentType.ARRAY_STRING.is_valid(["a"]) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index b98d56147e..9e9fc2e9ec 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -7,10 +7,10 @@ from dataclasses import dataclass import pytest from dify_graph.entities import GraphInitParams +from 
dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -183,3 +183,36 @@ def test_graph_validation_blocks_start_and_trigger_coexistence( Graph.init(graph_config=graph_config, node_factory=node_factory) assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues) + + +def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + { + "id": "start", + "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + { + "id": "note", + "type": "custom-note", + "data": { + "type": "", + "title": "", + "desc": "", + "text": "{}", + "theme": "blue", + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + assert "note" not in graph.nodes diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 61e0f12550..2b926d754c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dify_graph.entities.base_node_data import RetryConfig from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine.domain.graph_execution import GraphExecution @@ -12,7 +13,6 @@ from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.entities import RetryConfig from dify_graph.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index b9ae680f52..4e13177d2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,6 +10,7 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st +from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType from dify_graph.enums import ErrorStrategy from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -18,7 +19,6 @@ from dify_graph.graph_events import ( 
GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.nodes.base.entities import DefaultValue, DefaultValueType # Import the test framework from the new module from .test_mock_config import MockConfigBuilder diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 9f33a81985..338db9076e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -5,10 +5,10 @@ This module provides a MockNodeFactory that automatically detects and mocks node requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.base.node import Node @@ -75,39 +75,27 @@ class MockNodeFactory(DifyNodeFactory): NodeType.CODE: MockCodeNode, } - def create_node(self, node_config: Mapping[str, Any]) -> Node: + def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ Create a node instance, using mock implementations for third-party service nodes. :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - # Get node type from config - node_data = node_config.get("data", {}) - node_type_str = node_data.get("type") - - if not node_type_str: - # Fall back to parent implementation for nodes without type - return super().create_node(node_config) - - try: - node_type = NodeType(node_type_str) - except ValueError: - # Unknown node type, use parent implementation - return super().create_node(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = node_config.get("id") - if not node_id: - raise ValueError("Node config missing id") + node_id = typed_node_config["id"] # Create mock node instance mock_class = self._mock_node_types[node_type] if node_type == NodeType.CODE: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -117,7 +105,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type == NodeType.HTTP_REQUEST: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -129,7 +117,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -139,7 +127,7 @@ class MockNodeFactory(DifyNodeFactory): else: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -148,7 +136,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # 
For non-mocked node types, use parent implementation - return super().create_node(node_config) + return super().create_node(typed_node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 34e714a227..9e3574266c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,10 +11,10 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.nodes.agent import AgentNode from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.agent import AgentNode from dify_graph.nodes.code import CodeNode from dify_graph.nodes.document_extractor import DocumentExtractorNode from dify_graph.nodes.http_request import HttpRequestNode @@ -79,6 +79,14 @@ class MockNodeMixin: if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + if isinstance(self, AgentNode): + presentation_provider = MagicMock() + presentation_provider.get_icon.return_value = None + kwargs.setdefault("strategy_resolver", MagicMock()) + kwargs.setdefault("presentation_provider", presentation_provider) + kwargs.setdefault("runtime_support", MagicMock()) + kwargs.setdefault("message_transformer", MagicMock()) + super().__init__( id=id, config=config, diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index bf814d0c97..3fb775f934 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,7 +1,7 @@ import pytest +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node # Ensures that all node classes are imported. 
@@ -126,3 +126,20 @@ def test_init_subclass_sets_node_data_type_from_generic(): return "1" assert _AutoNode._node_data_type is _TestNodeData + + +def test_validate_node_data_uses_declared_node_data_type(): + """Public validation should hydrate the subclass-declared node data model.""" + + class _AutoNode(Node[_TestNodeData]): + node_type = NodeType.CODE + + @staticmethod + def version() -> str: + return "1" + + base_node_data = BaseNodeData.model_validate({"type": NodeType.CODE, "title": "Test"}) + + validated = _AutoNode.validate_node_data(base_node_data) + + assert isinstance(validated, _TestNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index f8d799e446..86d326aead 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,8 +1,8 @@ import types from collections.abc import Mapping +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 95cb653635..784e08edd2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -272,7 +272,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result == {} @@ -292,7 +292,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert "node_1.input_text" in result @@ -315,7 +315,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="code_node", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert len(result) == 3 @@ -338,7 +338,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_x", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"] @@ -437,7 +437,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.title == "Test Node" assert node._node_data.code_language == CodeLanguage.PYTHON3 @@ -453,7 +453,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.code_language == CodeLanguage.JAVASCRIPT diff --git 
a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index b95a7ad8ae..490df52533 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,3 +1,4 @@ +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from dify_graph.nodes.iteration.exc import ( @@ -388,3 +389,50 @@ class TestIterationNodeErrorStrategies: result = node._get_default_value_dict() assert isinstance(result, dict) + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "iteration_id": "iteration-node", + }, + } + + IterationNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration", "result"], + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="iteration-node", + node_data=IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration", "result"], + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index e929d652fd..b7a7a9c938 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -410,14 +410,14 @@ class TestKnowledgeRetrievalNode: """Test _extract_variable_selector_to_variable_mapping class method.""" # Arrange node_id = "knowledge_node_1" - node_data = { - "type": "knowledge-retrieval", - "title": "Knowledge Retrieval", - "dataset_ids": [str(uuid.uuid4())], - "retrieval_mode": "multiple", - "query_variable_selector": ["start", "query"], - "query_attachment_selector": ["start", "attachments"], - } + node_data = KnowledgeRetrievalNodeData( + type="knowledge-retrieval", + title="Knowledge Retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + query_variable_selector=["start", "query"], + query_attachment_selector=["start", "attachments"], + ) graph_config = {} # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 44abf430c0..0d81e7762b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -4,8 +4,9 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData +from 
dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -40,13 +41,26 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state +def _build_node_config() -> NodeConfigDict: + return NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": NodeType.ANSWER.value, + "title": "Sample", + "foo": "bar", + }, + } + ) + + def test_node_hydrates_data_during_initialization(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) node = _SampleNode( id="node-1", - config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + config=_build_node_config(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -72,7 +86,7 @@ def test_node_accepts_invoke_from_enum(): node = _SampleNode( id="node-1", - config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + config=_build_node_config(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -99,3 +113,17 @@ def test_missing_generic_argument_raises_type_error(): def _run(self): raise NotImplementedError + + +def test_base_node_data_keeps_dict_style_access_compatibility(): + node_data = _SampleNodeData.model_validate( + { + "type": NodeType.ANSWER.value, + "title": "Sample", + "foo": "bar", + } + ) + + assert node_data["foo"] == "bar" + assert node_data.get("foo") == "bar" + assert node_data.get("missing", "fallback") == "fallback" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py new file mode 100644 index 0000000000..6372583839 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -0,0 +1,52 @@ +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.nodes.loop.entities import LoopNodeData +from dify_graph.nodes.loop.loop_node import LoopNode + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "loop_id": "loop-node", + }, + } + + LoopNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="loop-node", + node_data=LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 410c4993e4..61b18566b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias(): def test_webhook_data_validation_errors(): """Test WebhookData validation errors.""" - # Title is required (inherited from BaseNodeData) - with pytest.raises(ValidationError): - WebhookData() # Invalid method with pytest.raises(ValidationError): @@ -254,6 +251,36 @@ def test_webhook_data_sequence_fields(): assert len(data.headers) == 1 # Should still be 1 +def test_webhook_data_rejects_non_string_header_types(): + """Headers should stay string-only because the runtime does not coerce header values.""" + for param_type in ["number", "boolean", "object", "array[string]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + headers=[WebhookParameter(name="X-Test", type=param_type)], + ) + + +def test_webhook_data_limits_query_param_types_to_scalar_values(): + """Query params only support scalar conversions in the current runtime.""" + data = WebhookData( + title="Test", + params=[ + WebhookParameter(name="count", type="number"), + WebhookParameter(name="enabled", type="boolean"), + ], + ) + assert data.params[0].type == "number" + assert data.params[1].type == "boolean" + + for param_type in ["object", "array[string]", "array[number]", "array[boolean]", "array[object]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + params=[WebhookParameter(name="test", type=param_type)], + ) + + def test_webhook_data_sync_mode(): """Test WebhookData SyncMode nested enum.""" # Test that SyncMode enum exists and has expected value @@ -297,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.nodes.base import BaseNodeData + from dify_graph.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index f2273e441e..a821e361c5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,6 +1,6 @@ import pytest -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError from dify_graph.nodes.trigger_webhook.exc import ( WebhookConfigError, WebhookNodeError, diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py new file mode 100644 index 0000000000..934e29546c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -0,0 +1,602 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.workflow import node_factory +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.variables.segments import StringSegment + + +def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: + assert config["id"] == node_id + assert isinstance(config["data"],
BaseNodeData) + assert config["data"].type == node_type + assert config["data"].version == version + + +class TestFetchMemory: + @pytest.mark.parametrize( + ("conversation_id", "memory_config"), + [ + (None, object()), + ("conversation-id", None), + ], + ) + def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config): + result = node_factory.fetch_memory( + conversation_id=conversation_id, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_returns_none_when_conversation_does_not_exist(self, monkeypatch): + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return None + + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch): + conversation = sentinel.conversation + memory = sentinel.memory + + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return conversation + + token_buffer_memory = MagicMock(return_value=memory) + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is memory + token_buffer_memory.assert_called_once_with( + conversation=conversation, + model_instance=sentinel.model_instance, + ) + + +class TestDefaultWorkflowCodeExecutor: + def test_execute_delegates_to_code_executor(self, monkeypatch): + executor = node_factory.DefaultWorkflowCodeExecutor() + execute_workflow_code_template = MagicMock(return_value={"answer": "ok"}) + monkeypatch.setattr( + node_factory.CodeExecutor, + "execute_workflow_code_template", + execute_workflow_code_template, + ) + + result = executor.execute( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + assert result == {"answer": "ok"} + execute_workflow_code_template.assert_called_once_with( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + def test_is_execution_error_checks_code_execution_error_type(self): + executor = node_factory.DefaultWorkflowCodeExecutor() + + assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True + assert executor.is_execution_error(RuntimeError("boom")) is False + + +class TestDifyNodeFactoryInit: + def test_init_builds_default_dependencies(self): + graph_init_params = SimpleNamespace(run_context={"context": "value"}) + 
graph_runtime_state = sentinel.graph_runtime_state + dify_context = SimpleNamespace(tenant_id="tenant-id") + template_renderer = sentinel.template_renderer + rag_retrieval = sentinel.rag_retrieval + unstructured_api_config = sentinel.unstructured_api_config + http_request_config = sentinel.http_request_config + credentials_provider = sentinel.credentials_provider + model_factory = sentinel.model_factory + + with ( + patch.object( + node_factory.DifyNodeFactory, + "_resolve_dify_context", + return_value=dify_context, + ) as resolve_dify_context, + patch.object( + node_factory, + "CodeExecutorJinja2TemplateRenderer", + return_value=template_renderer, + ) as renderer_factory, + patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval), + patch.object( + node_factory, + "UnstructuredApiConfig", + return_value=unstructured_api_config, + ), + patch.object( + node_factory, + "build_http_request_config", + return_value=http_request_config, + ), + patch.object( + node_factory, + "build_dify_model_access", + return_value=(credentials_provider, model_factory), + ) as build_dify_model_access, + ): + factory = node_factory.DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + resolve_dify_context.assert_called_once_with(graph_init_params.run_context) + build_dify_model_access.assert_called_once_with("tenant-id") + renderer_factory.assert_called_once() + assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor + assert factory.graph_init_params is graph_init_params + assert factory.graph_runtime_state is graph_runtime_state + assert factory._dify_context is dify_context + assert factory._template_renderer is template_renderer + assert factory._rag_retrieval is rag_retrieval + assert factory._document_extractor_unstructured_api_config is unstructured_api_config + assert factory._http_request_config is http_request_config + assert factory._llm_credentials_provider is credentials_provider + assert factory._llm_model_factory is model_factory + + +class TestDifyNodeFactoryResolveContext: + def test_requires_reserved_context_key(self): + with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY): + node_factory.DifyNodeFactory._resolve_dify_context({}) + + def test_returns_existing_dify_context(self): + dify_context = DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context}) + + assert result is dify_context + + def test_validates_mapping_context(self): + raw_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant-id", + "app_id": "app-id", + "user_id": "user-id", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + } + + result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context) + + assert isinstance(result, DifyRunContext) + assert result.tenant_id == "tenant-id" + + +class TestDifyNodeFactoryCreateNode: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory.graph_init_params = sentinel.graph_init_params + factory.graph_runtime_state = sentinel.graph_runtime_state + factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id") + factory._code_executor = sentinel.code_executor + factory._code_limits = sentinel.code_limits + factory._template_renderer = sentinel.template_renderer + 
factory._template_transform_max_output_length = 2048 + factory._http_request_http_client = sentinel.http_client + factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._http_request_file_manager = sentinel.file_manager + factory._rag_retrieval = sentinel.rag_retrieval + factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config + factory._http_request_config = sentinel.http_request_config + factory._llm_credentials_provider = sentinel.credentials_provider + factory._llm_model_factory = sentinel.model_factory + return factory + + def test_rejects_unknown_node_type(self, factory): + with pytest.raises(ValueError, match="Input should be"): + factory.create_node({"id": "node-id", "data": {"type": "missing"}}) + + def test_rejects_missing_class_mapping(self, monkeypatch, factory): + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(side_effect=ValueError("No class mapping found for node type: start")), + ) + + with pytest.raises(ValueError, match="No class mapping found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + + def test_rejects_missing_latest_class(self, monkeypatch, factory): + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(side_effect=ValueError("No latest version class found for node type: start")), + ) + + with pytest.raises(ValueError, match="No latest version class found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + + def test_uses_version_specific_class_when_available(self, monkeypatch, factory): + matched_node = sentinel.matched_node + latest_node_class = MagicMock(return_value=sentinel.latest_node) + matched_node_class = MagicMock(return_value=matched_node) + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(return_value=matched_node_class), + ) + + result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + + assert result is matched_node + matched_node_class.assert_called_once() + kwargs = matched_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + latest_node_class.assert_not_called() + + def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory): + latest_node = sentinel.latest_node + latest_node_class = MagicMock(return_value=latest_node) + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(return_value=latest_node_class), + ) + + result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + + assert result is latest_node + latest_node_class.assert_called_once() + kwargs = latest_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + @pytest.mark.parametrize( + ("node_type", "constructor_name"), + [ + (NodeType.CODE, "CodeNode"), + (NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"), + (NodeType.HTTP_REQUEST, 
"HttpRequestNode"), + (NodeType.HUMAN_INPUT, "HumanInputNode"), + (NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"), + (NodeType.DATASOURCE, "DatasourceNode"), + (NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), + (NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"), + ], + ) + def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name): + created_node = object() + constructor = MagicMock(name=constructor_name, return_value=created_node) + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(return_value=constructor), + ) + + if constructor_name == "HumanInputNode": + form_repository = sentinel.form_repository + form_repository_impl = MagicMock(return_value=form_repository) + monkeypatch.setattr( + node_factory, + "HumanInputFormRepositoryImpl", + form_repository_impl, + ) + elif constructor_name == "KnowledgeIndexNode": + index_processor = sentinel.index_processor + summary_index = sentinel.summary_index + monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor)) + monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index)) + + node_config = {"id": "node-id", "data": {"type": node_type.value}} + result = factory.create_node(node_config) + + assert result is created_node + kwargs = constructor.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + if constructor_name == "CodeNode": + assert kwargs["code_executor"] is sentinel.code_executor + assert kwargs["code_limits"] is sentinel.code_limits + elif constructor_name == "TemplateTransformNode": + assert kwargs["template_renderer"] is sentinel.template_renderer + assert kwargs["max_output_length"] == 2048 + elif constructor_name == "HttpRequestNode": + assert kwargs["http_request_config"] is sentinel.http_request_config + assert kwargs["http_client"] is sentinel.http_client + assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory + assert kwargs["file_manager"] is sentinel.file_manager + elif constructor_name == "HumanInputNode": + assert kwargs["form_repository"] is form_repository + form_repository_impl.assert_called_once_with(tenant_id="tenant-id") + elif constructor_name == "KnowledgeIndexNode": + assert kwargs["index_processor"] is index_processor + assert kwargs["summary_index_service"] is summary_index + elif constructor_name == "DatasourceNode": + assert kwargs["datasource_manager"] is node_factory.DatasourceManager + elif constructor_name == "KnowledgeRetrievalNode": + assert kwargs["rag_retrieval"] is sentinel.rag_retrieval + elif constructor_name == "DocumentExtractorNode": + assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config + assert kwargs["http_client"] is sentinel.http_client + + @pytest.mark.parametrize( + ("node_type", "constructor_name", "expected_extra_kwargs"), + [ + (NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}), + (NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}), + (NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), + ], + ) + def test_creates_model_backed_nodes( + self, + monkeypatch, + factory, + node_type, + constructor_name, + expected_extra_kwargs, + ): + created_node = object() + constructor = MagicMock(name=constructor_name, 
return_value=created_node) + monkeypatch.setattr( + node_factory, + "resolve_workflow_node_class", + MagicMock(return_value=constructor), + ) + llm_init_kwargs = { + "credentials_provider": sentinel.credentials_provider, + "model_factory": sentinel.model_factory, + "model_instance": sentinel.model_instance, + "memory": sentinel.memory, + **expected_extra_kwargs, + } + build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs) + factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs + + node_config = {"id": "node-id", "data": {"type": node_type.value}} + result = factory.create_node(node_config) + + assert result is created_node + build_llm_init_kwargs.assert_called_once() + helper_kwargs = build_llm_init_kwargs.call_args.kwargs + assert helper_kwargs["node_class"] is constructor + assert isinstance(helper_kwargs["node_data"], BaseNodeData) + assert helper_kwargs["node_data"].type == node_type + assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR) + + constructor_kwargs = constructor.call_args.kwargs + assert constructor_kwargs["id"] == "node-id" + _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) + assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params + assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider + assert constructor_kwargs["model_factory"] is sentinel.model_factory + assert constructor_kwargs["model_instance"] is sentinel.model_instance + assert constructor_kwargs["memory"] is sentinel.memory + for key, value in expected_extra_kwargs.items(): + assert constructor_kwargs[key] is value + + +class TestDifyNodeFactoryModelInstance: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._llm_credentials_provider = MagicMock() + factory._llm_model_factory = MagicMock() + return factory + + @pytest.fixture + def llm_model_setup(self, factory): + def _configure( + *, + completion_params=None, + has_provider_model=True, + model_schema=sentinel.model_schema, + ): + credentials = {"api_key": "secret"} + node_data_model = SimpleNamespace( + provider="provider", + name="model", + mode="chat", + completion_params=completion_params or {}, + ) + node_data = SimpleNamespace(model=node_data_model) + provider_model = MagicMock() if has_provider_model else None + provider_model_bundle = SimpleNamespace( + configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model)) + ) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + provider=None, + model_name=None, + credentials=None, + parameters=None, + stop=None, + ) + factory._llm_credentials_provider.fetch.return_value = credentials + factory._llm_model_factory.init_model_instance.return_value = model_instance + return SimpleNamespace( + node_data=node_data, + credentials=credentials, + provider_model=provider_model, + model_type_instance=model_type_instance, + model_instance=model_instance, + ) + + return _configure + + def test_requires_llm_mode(self, factory): + node_data = SimpleNamespace( + model=SimpleNamespace( + provider="provider", + name="model", + mode="", + completion_params={}, + ) + ) + + with pytest.raises(node_factory.LLMModeRequiredError, match="LLM 
mode is required"): + factory._build_model_instance_for_llm_node(node_data) + + def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(has_provider_model=False) + + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) + + def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(model_schema=None) + + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) + + setup.provider_model.raise_for_status.assert_called_once() + + def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup): + setup = llm_model_setup( + completion_params={"temperature": 0.3, "stop": "not-a-list"}, + model_schema={"schema": "value"}, + ) + + result = factory._build_model_instance_for_llm_node(setup.node_data) + + assert result is setup.model_instance + assert result.provider == "provider" + assert result.model_name == "model" + assert result.credentials == setup.credentials + assert result.parameters == {"temperature": 0.3} + assert result.stop == () + assert result.model_type_instance is setup.model_type_instance + setup.provider_model.raise_for_status.assert_called_once() + + +class TestDifyNodeFactoryMemory: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._dify_context = SimpleNamespace(app_id="app-id") + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + return factory + + def test_returns_none_when_memory_is_not_configured(self, factory): + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=None), + model_instance=sentinel.model_instance, + ) + + assert result is None + factory.graph_runtime_state.variable_pool.get.assert_not_called() + + def test_uses_string_segment_conversation_id(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id") + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + factory.graph_runtime_state.variable_pool.get.assert_called_once_with( + ["sys", SystemVariableKey.CONVERSATION_ID] + ) + fetch_memory.assert_called_once_with( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + fetch_memory.assert_called_once_with( + conversation_id=None, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py 
b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 0aa6ec3f45..93ba7f3333 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -9,6 +9,7 @@ from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.file.enums import FileType from dify_graph.file.models import File, FileTransferMethod from dify_graph.nodes.code.code_node import CodeNode @@ -124,7 +125,7 @@ class TestWorkflowEntry: def get_node_config_by_id(self, target_id: str): assert target_id == node_id - return node_config + return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={}) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py new file mode 100644 index 0000000000..68e42894fc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -0,0 +1,656 @@ +from collections import UserString +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow import workflow_entry +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.graph_events import GraphRunFailedEvent +from dify_graph.nodes import NodeType +from dify_graph.runtime import ChildGraphNotFoundError + + +def _build_typed_node_config(node_type: NodeType): + return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}}) + + +class TestWorkflowChildEngineBuilder: + @pytest.mark.parametrize( + ("graph_config", "node_id", "expected"), + [ + ({"nodes": [{"id": "root"}]}, "root", True), + ({"nodes": [{"id": "root"}]}, "other", False), + ({"nodes": "invalid"}, "root", None), + ({"nodes": ["invalid"]}, "root", None), + ], + ) + def test_has_node_id(self, graph_config, node_id, expected): + result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id) + + assert result is expected + + def test_build_child_engine_raises_when_root_node_is_missing(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + + with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory): + with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"): + builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": []}, + root_node_id="missing", + ) + + def test_build_child_engine_constructs_graph_engine_and_layers(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + child_graph = sentinel.child_graph + child_engine = MagicMock() + quota_layer = sentinel.quota_layer + additional_layers = [sentinel.layer_one, sentinel.layer_two] + + with ( + patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory, + patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init, + 
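The tri-state expectations in test_has_node_id above (True or False for a well-formed node list, None when the config cannot be inspected) suggest a guard of roughly this shape. This sketch is inferred from the parametrized cases, not copied from the implementation:

def _has_node_id(graph_config: dict, node_id: str) -> bool | None:
    nodes = graph_config.get("nodes")
    # Malformed configs return None so callers can tell "absent"
    # apart from "uninspectable".
    if not isinstance(nodes, list):
        return None
    try:
        return any(node["id"] == node_id for node in nodes)
    except TypeError:  # e.g. a bare string where a node dict was expected
        return None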
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer), + ): + result = builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": [{"id": "root"}]}, + root_node_id="root", + layers=additional_layers, + ) + + assert result is child_engine + dify_node_factory.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + graph_init.assert_called_once_with( + graph_config={"nodes": [{"id": "root"}]}, + node_factory=sentinel.factory, + root_node_id="root", + ) + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id", + graph=child_graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=builder, + ) + assert child_engine.layer.call_args_list == [ + ((quota_layer,), {}), + ((sentinel.layer_one,), {}), + ((sentinel.layer_two,), {}), + ] + + +class TestWorkflowEntryInit: + def test_rejects_call_depth_above_limit(self): + call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1 + + with pytest.raises(ValueError, match="Max workflow call depth"): + workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=call_depth, + variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + ) + + def test_applies_debug_and_observability_layers(self): + graph_engine = MagicMock() + debug_layer = sentinel.debug_layer + execution_limits_layer = sentinel.execution_limits_layer + llm_quota_layer = sentinel.llm_quota_layer + observability_layer = sentinel.observability_layer + + with ( + patch.object(workflow_entry.dify_config, "DEBUG", True), + patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False), + patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True), + patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer, + patch.object( + workflow_entry, + "ExecutionLimitsLayer", + return_value=execution_limits_layer, + ) as execution_limits_layer_cls, + patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer), + patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer), + ): + entry = workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id-123456", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=None, + ) + + 
assert entry.command_channel is sentinel.command_channel + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id-123456", + graph=sentinel.graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=entry._child_engine_builder, + ) + debug_logging_layer.assert_called_once_with( + level="DEBUG", + include_inputs=True, + include_outputs=True, + include_process_data=False, + logger_name="GraphEngine.Debug.workflow", + ) + execution_limits_layer_cls.assert_called_once_with( + max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + assert graph_engine.layer.call_args_list == [ + ((debug_layer,), {}), + ((execution_limits_layer,), {}), + ((llm_quota_layer,), {}), + ((observability_layer,), {}), + ] + + +class TestWorkflowEntryRun: + def test_run_swallows_generate_task_stopped_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = GenerateTaskStoppedError() + + assert list(entry.run()) == [] + + def test_run_emits_failed_event_for_unexpected_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = RuntimeError("boom") + + events = list(entry.run()) + + assert len(events) == 1 + assert isinstance(events[0], GraphRunFailedEvent) + assert events[0].error == "boom" + + +class TestWorkflowEntrySingleStepRun: + def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once_with( + variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER, + variable_pool=sentinel.variable_pool, + variable_mapping={}, + user_inputs={"question": "hello"}, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, 
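Several of these suites (TestWorkflowEntryRun above, and the node-factory fixtures earlier) lean on the same trick: object.__new__ creates an instance without running __init__, so one method can be exercised with only the attributes it actually reads. A generic, self-contained illustration (all names here are invented for the example):

class Service:
    def __init__(self):
        raise RuntimeError("expensive wiring we never want in unit tests")

    def double(self, x: int) -> int:
        return self.factor * x

def test_double_without_running_init():
    svc = object.__new__(Service)  # bypasses __init__ entirely
    svc.factor = 2                 # supply only what double() touches
    assert svc.double(21) == 42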
+ user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_skips_user_input_mapping_for_datasource_nodes(self): + class FakeDatasourceNode: + id = "node-id" + node_type = "datasource" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {"question": ["node", "question"]} + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once() + mapping_user_inputs_to_variable_pool.assert_not_called() + + def test_wraps_traced_node_run_failures(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + @staticmethod + def version(): + return "1" + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool"), + patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + side_effect=RuntimeError("boom"), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + ) + + with pytest.raises(WorkflowNodeRunFailedError): + workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={}, + variable_pool=sentinel.variable_pool, + ) + + +class TestWorkflowEntryHelpers: + def test_create_single_node_graph_builds_start_edge(self): + graph = workflow_entry.WorkflowEntry._create_single_node_graph( + 
node_id="target-node", + node_data={"type": NodeType.PARAMETER_EXTRACTOR}, + node_width=320, + node_height=180, + ) + + assert graph["nodes"][0]["id"] == "start" + assert graph["nodes"][1]["id"] == "target-node" + assert graph["nodes"][1]["width"] == 320 + assert graph["nodes"][1]["height"] == 180 + assert graph["edges"] == [ + { + "source": "start", + "target": "target-node", + "sourceHandle": "source", + "targetHandle": "target", + } + ] + + def test_run_free_node_rejects_unsupported_types(self): + with pytest.raises(ValueError, match="Node type start not supported"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.START.value}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_rejects_missing_node_class(self, monkeypatch): + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=None), + ) + + with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls, + patch.object( + workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params + ) as graph_init_params, + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object( + workflow_entry, "build_dify_run_context", return_value={"_dify": "context"} + ) as build_dify_run_context, + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + node, generator = workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + variable_pool_cls.assert_called_once_with( + system_variables=sentinel.system_variables, + user_inputs={}, + environment_variables=[], + ) + build_dify_run_context.assert_called_once_with( + tenant_id="tenant-id", + app_id="", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_params.assert_called_once_with( + workflow_id="", + 
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph( + "node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"} + ), + run_context={"_dify": "context"}, + call_depth=0, + ) + dify_node_factory_cls.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_run_free_node_wraps_execution_failures(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool), + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory), + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + side_effect=RuntimeError("boom"), + ), + ): + with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + def test_handle_special_values_serializes_nested_files(self): + file = File( + tenant_id="tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.png", + filename="image.png", + extension=".png", + ) + + result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}}) + + assert result == { + "file": file.to_dict(), + "nested": {"files": [file.to_dict()]}, + } + + def test_handle_special_values_returns_none_for_none(self): + assert workflow_entry.WorkflowEntry._handle_special_values(None) is None + + def test_handle_special_values_returns_scalar_as_is(self): + assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text" + + +class TestMappingUserInputsBranches: + def test_rejects_invalid_node_variable_key(self): + class EmptySplitKey(UserString): + def split(self, _sep=None): + return [] + + with pytest.raises(ValueError, match="Invalid node variable broken"): + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={EmptySplitKey("broken"): ["node", "input"]}, + user_inputs={}, + variable_pool=MagicMock(), + tenant_id="tenant-id", + ) + + def test_skips_none_user_input_when_variable_already_exists(self): + variable_pool = MagicMock() + variable_pool.get.return_value = None + + 
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.input": ["target", "input"]}, + user_inputs={"node.input": None}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_not_called() + + def test_merges_structured_output_values(self): + variable_pool = MagicMock() + variable_pool.get.side_effect = [ + None, + SimpleNamespace(value={"existing": "value"}), + ] + + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.answer": ["target", "structured_output", "answer"]}, + user_inputs={"node.answer": "new-value"}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_called_once_with( + ["target", "structured_output"], + {"existing": "value", "answer": "new-value"}, + ) + + +class TestWorkflowEntryTracing: + def test_traced_node_run_reports_success(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + yield "event" + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert events == ["event"] + layer.on_graph_start.assert_called_once_with() + layer.on_node_run_start.assert_called_once() + layer.on_node_run_end.assert_called_once_with( + layer.on_node_run_start.call_args.args[0], + None, + ) + + def test_traced_node_run_reports_errors(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + raise RuntimeError("boom") + yield + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + with pytest.raises(RuntimeError, match="boom"): + list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py new file mode 100644 index 0000000000..2410d16d63 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py @@ -0,0 +1,964 @@ +"""Comprehensive unit tests for dify_graph/model_runtime/callbacks/base_callback.py""" + +from unittest.mock import MagicMock,
patch + +import pytest + +from dify_graph.model_runtime.callbacks.base_callback import ( + _TEXT_COLOR_MAPPING, + Callback, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool + +# --------------------------------------------------------------------------- +# Concrete implementation of the abstract Callback for testing +# --------------------------------------------------------------------------- + + +class ConcreteCallback(Callback): + """A minimal concrete subclass that satisfies all abstract methods.""" + + def __init__(self, raise_error: bool = False): + self.raise_error = raise_error + # Track invocations + self.before_invoke_calls: list[dict] = [] + self.new_chunk_calls: list[dict] = [] + self.after_invoke_calls: list[dict] = [] + self.invoke_error_calls: list[dict] = [] + + def on_before_invoke( + self, + llm_instance, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.before_invoke_calls.append( + { + "llm_instance": llm_instance, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + # To cover the 'raise NotImplementedError()' in the base class + try: + super().on_before_invoke( + llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_new_chunk( + self, + llm_instance, + chunk, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.new_chunk_calls.append( + { + "llm_instance": llm_instance, + "chunk": chunk, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_new_chunk( + llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_after_invoke( + self, + llm_instance, + result, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.after_invoke_calls.append( + { + "llm_instance": llm_instance, + "result": result, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_after_invoke( + llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_invoke_error( + self, + llm_instance, + ex, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.invoke_error_calls.append( + { + "llm_instance": llm_instance, + "ex": ex, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_invoke_error( + llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + +# 
--------------------------------------------------------------------------- +# Subclasses that deliberately leave abstract methods un-implemented are +# defined inline in TestCallbackAbstract below, to verify that instantiation +# raises TypeError. +# --------------------------------------------------------------------------- + + +# =========================================================================== +# Tests for _TEXT_COLOR_MAPPING module-level constant +# =========================================================================== + + +class TestTextColorMapping: + """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" + + def test_contains_all_expected_colors(self): + expected_keys = {"blue", "yellow", "pink", "green", "red"} + assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys + + def test_blue_escape_code(self): + assert _TEXT_COLOR_MAPPING["blue"] == "36;1" + + def test_yellow_escape_code(self): + assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" + + def test_pink_escape_code(self): + assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" + + def test_green_escape_code(self): + assert _TEXT_COLOR_MAPPING["green"] == "32;1" + + def test_red_escape_code(self): + assert _TEXT_COLOR_MAPPING["red"] == "31;1" + + def test_mapping_is_dict(self): + assert isinstance(_TEXT_COLOR_MAPPING, dict) + + def test_all_values_are_strings(self): + for key, value in _TEXT_COLOR_MAPPING.items(): + assert isinstance(value, str), f"Value for {key!r} should be str" + + +# =========================================================================== +# Tests for the Callback ABC itself +# =========================================================================== + + +class TestCallbackAbstract: + """Tests verifying Callback is a proper ABC.""" + + def test_cannot_instantiate_abstract_class_directly(self): + """Callback cannot be instantiated since it has abstract methods.""" + with pytest.raises(TypeError): + Callback() # type: ignore[abstract] + + def test_concrete_subclass_can_be_instantiated(self): + cb = ConcreteCallback() + assert isinstance(cb, Callback) + + def test_default_raise_error_is_false(self): + cb = ConcreteCallback() + assert cb.raise_error is False + + def test_raise_error_can_be_set_to_true(self): + cb = ConcreteCallback(raise_error=True) + assert cb.raise_error is True + + def test_subclass_missing_on_before_invoke_raises_type_error(self): + """A subclass missing any single abstract method cannot be instantiated.""" + + class IncompleteCallback(Callback): + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_new_chunk_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_after_invoke_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_invoke_error_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ...
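The missing-method tests in TestCallbackAbstract all rely on the same abc rule: a subclass stays abstract until every @abstractmethod is overridden, and instantiating it raises TypeError. A self-contained illustration with generic names (not from this codebase):

from abc import ABC, abstractmethod

class Handler(ABC):
    @abstractmethod
    def on_start(self): ...

    @abstractmethod
    def on_end(self): ...

class PartialHandler(Handler):
    def on_start(self):  # on_end is still abstract
        return "started"

# PartialHandler() raises TypeError, roughly:
# "Can't instantiate abstract class PartialHandler with abstract method on_end"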
+ + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for the on_before_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.model = "gpt-4" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 0.7} + + def test_on_before_invoke_called_with_required_args(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 1 + call = self.cb.before_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["model"] == self.model + assert call["credentials"] == self.credentials + assert call["prompt_messages"] is self.prompt_messages + assert call["model_parameters"] is self.model_parameters + + def test_on_before_invoke_defaults_tools_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["tools"] is None + + def test_on_before_invoke_defaults_stop_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stop"] is None + + def test_on_before_invoke_defaults_stream_true(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stream"] is True + + def test_on_before_invoke_defaults_user_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["user"] is None + + def test_on_before_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["stop1", "stop2"] + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="user-123", + ) + call = self.cb.before_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "user-123" + + def test_on_before_invoke_called_multiple_times(self): + for i in range(3): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=f"model-{i}", + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 3 + assert self.cb.before_invoke_calls[2]["model"] == "model-2" + + +# =========================================================================== +# Tests for on_new_chunk +# 
=========================================================================== + + +class TestOnNewChunk: + """Tests for the on_new_chunk callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.chunk = MagicMock(spec=LLMResultChunk) + self.model = "gpt-3.5-turbo" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"max_tokens": 256} + + def test_on_new_chunk_called_with_required_args(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 1 + call = self.cb.new_chunk_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["chunk"] is self.chunk + assert call["model"] == self.model + assert call["credentials"] == self.credentials + + def test_on_new_chunk_defaults_tools_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["tools"] is None + + def test_on_new_chunk_defaults_stop_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stop"] is None + + def test_on_new_chunk_defaults_stream_true(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stream"] is True + + def test_on_new_chunk_defaults_user_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["user"] is None + + def test_on_new_chunk_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["END"] + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="chunk-user", + ) + call = self.cb.new_chunk_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "chunk-user" + + def test_on_new_chunk_called_multiple_times(self): + for i in range(5): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 5 + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for the on_after_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = 
MagicMock() + self.result = MagicMock(spec=LLMResult) + self.model = "claude-3" + self.credentials = {"api_key": "anthropic-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 1.0} + + def test_on_after_invoke_called_with_required_args(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.after_invoke_calls) == 1 + call = self.cb.after_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["result"] is self.result + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_after_invoke_defaults_tools_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["tools"] is None + + def test_on_after_invoke_defaults_stop_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stop"] is None + + def test_on_after_invoke_defaults_stream_true(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stream"] is True + + def test_on_after_invoke_defaults_user_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["user"] is None + + def test_on_after_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["STOP"] + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="after-user", + ) + call = self.cb.after_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "after-user" + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for the on_invoke_error callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.ex = ValueError("something went wrong") + self.model = "gemini-pro" + self.credentials = {"api_key": "google-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"top_p": 0.9} + + def test_on_invoke_error_called_with_required_args(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + 
assert len(self.cb.invoke_error_calls) == 1 + call = self.cb.invoke_error_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["ex"] is self.ex + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_invoke_error_defaults_tools_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["tools"] is None + + def test_on_invoke_error_defaults_stop_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stop"] is None + + def test_on_invoke_error_defaults_stream_true(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stream"] is True + + def test_on_invoke_error_defaults_user_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["user"] is None + + def test_on_invoke_error_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["HALT"] + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="error-user", + ) + call = self.cb.invoke_error_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "error-user" + + def test_on_invoke_error_accepts_various_exception_types(self): + for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]: + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=exc, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.invoke_error_calls) == 3 + + +# =========================================================================== +# Tests for print_text (concrete method on Callback) +# =========================================================================== + + +class TestPrintText: + """Tests for the concrete print_text method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + def test_print_text_without_color_prints_plain_text(self, capsys): + self.cb.print_text("hello world") + captured = capsys.readouterr() + assert captured.out == "hello world" + + def test_print_text_with_color_prints_colored_text(self, capsys): + self.cb.print_text("colored text", color="blue") + captured = capsys.readouterr() + # Should contain an ANSI escape sequence ("\033[" and "\x1b[" are the same string) + assert "colored text" in captured.out + assert "\x1b[" in captured.out + + def test_print_text_without_color_no_ansi(self, capsys): + self.cb.print_text("plain text", color=None) + captured = capsys.readouterr() + assert captured.out == "plain text" + # No ANSI
escape sequences + assert "\x1b" not in captured.out + + def test_print_text_default_end_is_empty_string(self, capsys): + self.cb.print_text("no newline") + captured = capsys.readouterr() + assert not captured.out.endswith("\n") + + def test_print_text_with_custom_end(self, capsys): + self.cb.print_text("with newline", end="\n") + captured = capsys.readouterr() + assert captured.out.endswith("\n") + + def test_print_text_with_empty_string(self, capsys): + self.cb.print_text("", color=None) + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_print_text_all_colors_work(self, color, capsys): + """Verify no KeyError is thrown for any valid color.""" + self.cb.print_text("test", color=color) + captured = capsys.readouterr() + assert "test" in captured.out + + def test_print_text_calls_get_colored_text_when_color_given(self): + with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: + with patch("builtins.print") as mock_print: + self.cb.print_text("hello", color="green") + mock_gct.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("[COLORED]", end="") + + def test_print_text_does_not_call_get_colored_text_when_no_color(self): + with patch.object(self.cb, "_get_colored_text") as mock_gct: + with patch("builtins.print"): + self.cb.print_text("hello", color=None) + mock_gct.assert_not_called() + + def test_print_text_passes_end_to_print(self): + with patch("builtins.print") as mock_print: + self.cb.print_text("text", end="---") + mock_print.assert_called_once_with("text", end="---") + + +# =========================================================================== +# Tests for _get_colored_text (private helper method) +# =========================================================================== + + +class TestGetColoredText: + """Tests for the _get_colored_text private method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) + def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): + result = self.cb._get_colored_text("text", color) + assert expected_code in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_contains_input_text(self, color): + result = self.cb._get_colored_text("hello", color) + assert "hello" in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_starts_with_escape(self, color): + result = self.cb._get_colored_text("text", color) + # Should start with an ANSI escape (\x1b or \u001b) + assert result.startswith("\x1b[") or result.startswith("\u001b[") + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_ends_with_reset(self, color): + result = self.cb._get_colored_text("text", color) + # Should end with the ANSI reset code + assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") + + def test_get_colored_text_returns_string(self): + result = self.cb._get_colored_text("text", "blue") + assert isinstance(result, str) + + def test_get_colored_text_blue_exact_format(self): + result = self.cb._get_colored_text("hello", "blue") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" + assert result == expected + + def test_get_colored_text_red_exact_format(self): + result = 
self.cb._get_colored_text("error", "red") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" + assert result == expected + + def test_get_colored_text_green_exact_format(self): + result = self.cb._get_colored_text("ok", "green") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" + assert result == expected + + def test_get_colored_text_yellow_exact_format(self): + result = self.cb._get_colored_text("warn", "yellow") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" + assert result == expected + + def test_get_colored_text_pink_exact_format(self): + result = self.cb._get_colored_text("info", "pink") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" + assert result == expected + + def test_get_colored_text_empty_string(self): + result = self.cb._get_colored_text("", "blue") + assert isinstance(result, str) + # Empty text should still have escape codes + assert _TEXT_COLOR_MAPPING["blue"] in result + + def test_get_colored_text_invalid_color_raises_key_error(self): + with pytest.raises(KeyError): + self.cb._get_colored_text("text", "purple") + + def test_get_colored_text_with_special_characters(self): + special = "hello\nworld\ttab" + result = self.cb._get_colored_text(special, "blue") + assert special in result + + def test_get_colored_text_with_long_text(self): + long_text = "a" * 10000 + result = self.cb._get_colored_text(long_text, "green") + assert long_text in result + + +# =========================================================================== +# Integration-style tests: full workflow through a ConcreteCallback +# =========================================================================== + + +class TestConcreteCallbackIntegration: + """End-to-end workflow tests using ConcreteCallback.""" + + def test_full_invocation_lifecycle(self): + """Simulate a complete LLM invocation lifecycle through all callbacks.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4o" + credentials = {"api_key": "sk-xyz"} + prompt_messages = [MagicMock(spec=PromptMessage)] + model_parameters = {"temperature": 0.5} + tools = [MagicMock(spec=PromptMessageTool)] + stop = [""] + user = "user-abc" + + # 1. Before invoke + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 2. Multiple chunks during streaming + for i in range(3): + chunk = MagicMock(spec=LLMResultChunk) + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 3. 
After invoke + result = MagicMock(spec=LLMResult) + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.new_chunk_calls) == 3 + assert len(cb.after_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 0 + + def test_error_lifecycle(self): + """Simulate an invoke that results in an error.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4" + credentials = {} + prompt_messages = [] + model_parameters = {} + + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + ex = RuntimeError("API timeout") + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 1 + assert cb.invoke_error_calls[0]["ex"] is ex + assert len(cb.after_invoke_calls) == 0 + + def test_print_text_with_color_in_integration(self, capsys): + """Verify print_text works correctly in a concrete instance.""" + cb = ConcreteCallback() + cb.print_text("SUCCESS", color="green", end="\n") + captured = capsys.readouterr() + assert "SUCCESS" in captured.out + assert "\n" in captured.out + + def test_print_text_no_color_in_integration(self, capsys): + cb = ConcreteCallback() + cb.print_text("plain output") + captured = capsys.readouterr() + assert captured.out == "plain output" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py new file mode 100644 index 0000000000..0c6c1fd191 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py @@ -0,0 +1,700 @@ +""" +Comprehensive unit tests for dify_graph/model_runtime/callbacks/logging_callback.py + +Coverage targets: + - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream, + prompt_message.name, model_parameters) + - LoggingCallback.on_new_chunk (writes to stdout) + - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent) + - LoggingCallback.on_invoke_error (logs exception via logger.exception) +""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_usage() -> LLMUsage: + """Return a minimal LLMUsage instance.""" + return LLMUsage( + prompt_tokens=10, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.01"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), +
completion_price_unit=Decimal("0.002"), + completion_price=Decimal("0.04"), + total_tokens=30, + total_price=Decimal("0.05"), + currency="USD", + latency=0.5, + ) + + +def _make_llm_result( + content: str = "hello world", + tool_calls: list | None = None, + model: str = "gpt-4", + system_fingerprint: str | None = "fp-abc", +) -> LLMResult: + """Return an LLMResult with an AssistantPromptMessage.""" + assistant_msg = AssistantPromptMessage( + content=content, + tool_calls=tool_calls or [], + ) + return LLMResult( + model=model, + message=assistant_msg, + usage=_make_usage(), + system_fingerprint=system_fingerprint, + ) + + +def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: + """Return a minimal LLMResultChunk.""" + return LLMResultChunk( + model="gpt-4", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content), + ), + ) + + +def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: + return UserPromptMessage(content=content, name=name) + + +def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: + return SystemPromptMessage(content=content) + + +def _make_tool(name: str = "my_tool") -> PromptMessageTool: + return PromptMessageTool(name=name, description="A tool", parameters={}) + + +def _make_tool_call( + call_id: str = "call-1", + func_name: str = "some_func", + arguments: str = '{"key": "value"}', +) -> AssistantPromptMessage.ToolCall: + return AssistantPromptMessage.ToolCall( + id=call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), + ) + + +# --------------------------------------------------------------------------- +# Fixture: shared LoggingCallback instance (no heavy state) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cb() -> LoggingCallback: + return LoggingCallback() + + +@pytest.fixture +def llm_instance() -> MagicMock: + return MagicMock() + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for LoggingCallback.on_before_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + *, + model: str = "gpt-4", + credentials: dict | None = None, + prompt_messages: list | None = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, + stream: bool = True, + user: str | None = None, + ): + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials or {}, + prompt_messages=prompt_messages or [], + model_parameters=model_parameters or {}, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): + """Calling with bare-minimum args should not raise.""" + self._invoke(cb, llm_instance) + + def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The model name must appear in print_text calls.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model="claude-3") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "claude-3" in calls_text + + def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Each key-value pair of 
model_parameters must be printed.""" + params = {"temperature": 0.7, "max_tokens": 512} + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model_parameters=params) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "temperature" in calls_text + assert "0.7" in calls_text + assert "max_tokens" in calls_text + assert "512" in calls_text + + def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): + """Empty model_parameters dict should not raise.""" + self._invoke(cb, llm_instance, model_parameters={}) + + # ------------------------------------------------------------------ + # stop branch + # ------------------------------------------------------------------ + + def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """stop words must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=["STOP", "END"]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "stop" in calls_text + + def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=None the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=[] (falsy) the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + # ------------------------------------------------------------------ + # tools branch + # ------------------------------------------------------------------ + + def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """Tool names must appear in output when tools are provided.""" + tools = [_make_tool("search"), _make_tool("calculate")] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=tools) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "search" in calls_text + assert "calculate" in calls_text + + def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=None the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=[] (falsy) the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + # ------------------------------------------------------------------ + # user branch + # ------------------------------------------------------------------ + + def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """User string must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, 
llm_instance, user="alice") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "alice" in calls_text + + def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When user=None the User line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, user=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "User:" not in calls_text + + # ------------------------------------------------------------------ + # stream branch + # ------------------------------------------------------------------ + + def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=True the [on_llm_new_chunk] marker must be printed.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=True) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" in calls_text + + def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=False) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" not in calls_text + + # ------------------------------------------------------------------ + # prompt_messages branch + # ------------------------------------------------------------------ + + def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has a name it must be printed.""" + msg = _make_user_prompt("hi", name="bob") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "bob" in calls_text + + def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has no name the name line must NOT appear.""" + msg = _make_user_prompt("hi", name=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tname:" not in calls_text + + def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Role and content of each PromptMessage must appear in output.""" + msg = _make_system_prompt("Be concise.") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "system" in calls_text + assert "Be concise." 
in calls_text + + def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All entries in prompt_messages are iterated and printed.""" + msgs = [ + _make_system_prompt("sys"), + _make_user_prompt("user msg"), + ] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=msgs) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "sys" in calls_text + assert "user msg" in calls_text + + # ------------------------------------------------------------------ + # Combination: everything provided + # ------------------------------------------------------------------ + + def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock): + """Supply stop, tools, user, multiple params, named message – no exception.""" + msgs = [_make_user_prompt("question", name="alice")] + tools = [_make_tool("tool_a")] + with patch.object(cb, "print_text"): + self._invoke( + cb, + llm_instance, + model="gpt-3.5", + model_parameters={"temperature": 1.0, "top_p": 0.9}, + tools=tools, + stop=["DONE"], + stream=True, + user="alice", + prompt_messages=msgs, + ) + + +# =========================================================================== +# Tests for on_new_chunk +# =========================================================================== + + +class TestOnNewChunk: + """Tests for LoggingCallback.on_new_chunk.""" + + def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock): + """on_new_chunk must write the chunk's text content to sys.stdout.""" + chunk = _make_chunk("hello from LLM") + written = [] + + with patch("sys.stdout") as mock_stdout: + mock_stdout.write.side_effect = written.append + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("hello from LLM") + mock_stdout.flush.assert_called_once() + + def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works correctly even when the chunk content is an empty string.""" + chunk = _make_chunk("") + with patch("sys.stdout") as mock_stdout: + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("") + mock_stdout.flush.assert_called_once() + + def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters are accepted without errors.""" + chunk = _make_chunk("data") + with patch("sys.stdout"): + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.5}, + tools=[_make_tool("t1")], + stop=["EOS"], + stream=True, + user="bob", + ) + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for LoggingCallback.on_after_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + result: LLMResult, + **kwargs, + ): + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def 
test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """After-invoke header, content, model, usage, fingerprint must be printed.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_after_invoke]" in calls_text + assert "hello world" in calls_text + assert "gpt-4" in calls_text + assert "fp-abc" in calls_text + + def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock): + """When there are no tool_calls the 'Tool calls:' block must NOT appear.""" + result = _make_llm_result(tool_calls=[]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" not in calls_text + + def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tool_calls exist their id, name, and JSON arguments must be printed.""" + tc = _make_tool_call( + call_id="call-xyz", + func_name="fetch_data", + arguments='{"url": "https://example.com"}', + ) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" in calls_text + assert "call-xyz" in calls_text + assert "fetch_data" in calls_text + # arguments should be JSON-dumped + assert "https://example.com" in calls_text + + def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All tool calls in the list must be iterated.""" + tcs = [ + _make_tool_call("id-1", "func_a", '{"a": 1}'), + _make_tool_call("id-2", "func_b", '{"b": 2}'), + ] + result = _make_llm_result(tool_calls=tcs) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "id-1" in calls_text + assert "func_a" in calls_text + assert "id-2" in calls_text + assert "func_b" in calls_text + + def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When system_fingerprint is None it should still be printed (as None).""" + result = _make_llm_result(system_fingerprint=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "System Fingerprint: None" in calls_text + + def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The usage object must appear in the printed output.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Usage:" in calls_text + + def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock): + """Verify json.dumps is applied to the arguments field (a string).""" + raw_args = '{"x": 42}' + tc = _make_tool_call(arguments=raw_args) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + + # Check if any call to print_text included the expected (json-encoded) arguments + # json.dumps(raw_args) produces a string starting and ending with 
quotes + expected_substring = json.dumps(raw_args) + found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list) + assert found, f"Expected {expected_substring} to be printed in one of the calls" + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + result = _make_llm_result() + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.9}, + tools=[_make_tool("t")], + stop=[""], + stream=False, + user="carol", + ) + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for LoggingCallback.on_invoke_error.""" + + def _invoke_error( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + ex: Exception, + **kwargs, + ): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """The [on_llm_invoke_error] banner must be printed.""" + with patch.object(cb, "print_text") as mock_print: + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, RuntimeError("boom")) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_invoke_error]" in calls_text + + def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock): + """logger.exception must be called with the exception.""" + ex = ValueError("something went wrong") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works with any exception type (TypeError, IOError, etc.).""" + for exc_cls in (TypeError, IOError, KeyError, Exception): + ex = exc_cls("error") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + ex = RuntimeError("fail") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger"): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.7}, + tools=[_make_tool("t")], + stop=["STOP"], + stream=True, + user="dave", + ) + + +# =========================================================================== +# Tests for print_text (inherited from Callback, exercised through LoggingCallback) +# =========================================================================== + + +class TestPrintText: + """Verify that print_text from the Callback base class works correctly.""" + + def 
test_print_text_with_color(self, cb: LoggingCallback, capsys): + """print_text with a known colour should emit an ANSI escape sequence.""" + cb.print_text("hello", color="blue") + captured = capsys.readouterr() + assert "hello" in captured.out + # ANSI escape codes should be present + assert "\x1b[" in captured.out + + def test_print_text_without_color(self, cb: LoggingCallback, capsys): + """print_text without colour should print plain text.""" + cb.print_text("plain text") + captured = capsys.readouterr() + assert "plain text" in captured.out + + def test_print_text_all_colours(self, cb: LoggingCallback, capsys): + """Verify all supported colour keys don't raise.""" + for colour in ("blue", "yellow", "pink", "green", "red"): + cb.print_text("x", color=colour) + captured = capsys.readouterr() + # All outputs should contain 'x' (5 calls) + assert captured.out.count("x") >= 5 + + +# =========================================================================== +# Integration-style test: real print_text called (no mocking) +# =========================================================================== + + +class TestLoggingCallbackIntegration: + """Light integration tests – real print_text calls, just checking no exceptions.""" + + def test_on_before_invoke_full_run(self, capsys): + """Full on_before_invoke run with all optional fields – verifies real output.""" + cb = LoggingCallback() + llm = MagicMock() + msgs = [_make_user_prompt("Who are you?", name="tester")] + tools = [_make_tool("calculator")] + cb.on_before_invoke( + llm_instance=llm, + model="gpt-4-turbo", + credentials={"api_key": "sk-xxx"}, + prompt_messages=msgs, + model_parameters={"temperature": 0.8}, + tools=tools, + stop=["STOP"], + stream=True, + user="test_user", + ) + captured = capsys.readouterr() + assert "gpt-4-turbo" in captured.out + assert "calculator" in captured.out + assert "test_user" in captured.out + assert "STOP" in captured.out + assert "tester" in captured.out + + def test_on_new_chunk_full_run(self, capsys): + """Full on_new_chunk run – verifies real stdout write.""" + cb = LoggingCallback() + chunk = _make_chunk("streaming token") + cb.on_new_chunk( + llm_instance=MagicMock(), + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "streaming token" in captured.out + + def test_on_after_invoke_full_run_with_tool_calls(self, capsys): + """Full on_after_invoke run with tool calls – verifies real output.""" + cb = LoggingCallback() + tc = _make_tool_call("call-99", "do_thing", '{"n": 5}') + result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz") + cb.on_after_invoke( + llm_instance=MagicMock(), + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "result content" in captured.out + assert "call-99" in captured.out + assert "do_thing" in captured.out + assert "fp-xyz" in captured.out + + def test_on_invoke_error_full_run(self, capsys): + """Full on_invoke_error run – just verifies no exception is raised.""" + cb = LoggingCallback() + ex = RuntimeError("something bad happened") + # logger.exception writes to stderr; we just confirm it doesn't crash + cb.on_invoke_error( + llm_instance=MagicMock(), + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "[on_llm_invoke_error]" in captured.out diff --git 
a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py new file mode 100644 index 0000000000..db147fb0cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py @@ -0,0 +1,35 @@ +from dify_graph.model_runtime.entities.common_entities import I18nObject + + +class TestI18nObject: + def test_i18n_object_with_both_languages(self): + """ + Test I18nObject when both zh_Hans and en_US are provided. + """ + i18n = I18nObject(zh_Hans="你好", en_US="Hello") + assert i18n.zh_Hans == "你好" + assert i18n.en_US == "Hello" + + def test_i18n_object_fallback_to_en_us(self): + """ + Test I18nObject when zh_Hans is missing, it should fallback to en_US. + """ + i18n = I18nObject(en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_none_zh_hans(self): + """ + Test I18nObject when zh_Hans is None, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans=None, en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_empty_zh_hans(self): + """ + Test I18nObject when zh_Hans is an empty string, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans="", en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py rename to api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py new file mode 100644 index 0000000000..a96a38f5cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py @@ -0,0 +1,210 @@ +import pytest + +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, + VideoPromptMessageContent, +) + + +class TestPromptMessageRole: + def test_value_of(self): + assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM + assert PromptMessageRole.value_of("user") == PromptMessageRole.USER + assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT + assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL + + with pytest.raises(ValueError, match="invalid prompt message type value invalid"): + PromptMessageRole.value_of("invalid") + + +class TestPromptMessageEntities: + def test_prompt_message_tool(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + assert tool.name == "test_tool" + assert tool.description == "test desc" + assert tool.parameters == {"foo": "bar"} + + def test_prompt_message_function(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + func = PromptMessageFunction(function=tool) + assert func.type == "function" + assert func.function == tool + 
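+    def test_prompt_message_function_model_dump(self):
+        # Hedged sketch (assumption: these entities are pydantic models, as the
+        # model_validate()/model_dump() usage elsewhere in this suite suggests):
+        # dumping a PromptMessageFunction should keep the fixed "function"
+        # discriminator and nest the wrapped tool's fields.
+        tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"})
+        dumped = PromptMessageFunction(function=tool).model_dump()
+        assert dumped["type"] == "function"
+        assert dumped["function"]["name"] == "test_tool"
+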
+ +class TestPromptMessageContent: + def test_text_content(self): + content = TextPromptMessageContent(data="hello") + assert content.type == PromptMessageContentType.TEXT + assert content.data == "hello" + + def test_image_content(self): + content = ImagePromptMessageContent( + format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH + ) + assert content.type == PromptMessageContentType.IMAGE + assert content.detail == ImagePromptMessageContent.DETAIL.HIGH + assert content.data == "data:image/jpeg;base64,abc" + + def test_image_content_url(self): + content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg") + assert content.data == "https://example.com/image.jpg" + + def test_audio_content(self): + content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg") + assert content.type == PromptMessageContentType.AUDIO + assert content.data == "data:audio/mpeg;base64,abc" + + def test_video_content(self): + content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4") + assert content.type == PromptMessageContentType.VIDEO + assert content.data == "data:video/mp4;base64,abc" + + def test_document_content(self): + content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf") + assert content.type == PromptMessageContentType.DOCUMENT + assert content.data == "data:application/pdf;base64,abc" + + +class TestPromptMessages: + def test_user_prompt_message(self): + msg = UserPromptMessage(content="hello") + assert msg.role == PromptMessageRole.USER + assert msg.content == "hello" + assert msg.is_empty() is False + assert msg.get_text_content() == "hello" + + def test_user_prompt_message_complex_content(self): + content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")] + msg = UserPromptMessage(content=content) + assert msg.get_text_content() == "hello world" + + # Test validation from dict + msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}]) + assert isinstance(msg2.content[0], TextPromptMessageContent) + assert msg2.content[0].data == "hi" + + def test_prompt_message_empty(self): + msg = UserPromptMessage(content=None) + assert msg.is_empty() is True + assert msg.get_text_content() == "" + + def test_assistant_prompt_message(self): + msg = AssistantPromptMessage(content="thinking...") + assert msg.role == PromptMessageRole.ASSISTANT + assert msg.is_empty() is False + + tool_call = AssistantPromptMessage.ToolCall( + id="call_1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call]) + assert msg_with_tools.is_empty() is False + assert msg_with_tools.role == PromptMessageRole.ASSISTANT + + def test_assistant_tool_call_id_transform(self): + tool_call = AssistantPromptMessage.ToolCall( + id=123, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + assert tool_call.id == "123" + + def test_system_prompt_message(self): + msg = SystemPromptMessage(content="you are a bot") + assert msg.role == PromptMessageRole.SYSTEM + assert msg.content == "you are a bot" + + def test_tool_prompt_message(self): + # Case 1: Both content and tool_call_id are present + msg = ToolPromptMessage(content="result", tool_call_id="call_1") + assert msg.role == PromptMessageRole.TOOL + assert 
msg.tool_call_id == "call_1" + assert msg.is_empty() is False + + # Case 2: Content is present, but tool_call_id is empty + msg_content_only = ToolPromptMessage(content="result", tool_call_id="") + assert msg_content_only.is_empty() is False + + # Case 3: Content is None, but tool_call_id is present + msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1") + assert msg_id_only.is_empty() is False + + # Case 4: Both content and tool_call_id are empty + msg_empty = ToolPromptMessage(content=None, tool_call_id="") + assert msg_empty.is_empty() is True + + def test_prompt_message_validation_errors(self): + with pytest.raises(KeyError): + # Invalid content type in list + UserPromptMessage(content=[{"type": "invalid", "data": "foo"}]) + + with pytest.raises(ValueError, match="invalid prompt message"): + # Not a dict or PromptMessageContent + UserPromptMessage(content=[123]) + + def test_prompt_message_serialization(self): + # Case: content is None + assert UserPromptMessage(content=None).serialize_content(None) is None + + # Case: content is str + assert UserPromptMessage(content="hello").serialize_content("hello") == "hello" + + # Case: content is list of dict + content_list = [{"type": "text", "data": "hi"}] + msg = UserPromptMessage(content=content_list) + assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}] + + # Case: content is Sequence but not list (e.g. tuple) + # To hit line 204, we can call serialize_content manually or + # try to pass a type that pydantic doesn't convert to list in its internal state. + # Actually, let's just call it manually on the instance. + msg = UserPromptMessage(content="test") + content_tuple = (TextPromptMessageContent(data="hi"),) + assert msg.serialize_content(content_tuple) == content_tuple + + def test_prompt_message_mixed_content_validation(self): + # Test branch: isinstance(prompt, PromptMessageContent) + # but not (TextPromptMessageContent | MultiModalPromptMessageContent) + # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) + + # We need a PromptMessageContent that is NOT Text or MultiModal. + # But PromptMessageContentUnionTypes discriminator handles this usually. + # We can bypass high-level validation by passing the object directly in a list. 
+ + class MockContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.TEXT + data: str + + mock_item = MockContent(data="test") + msg = UserPromptMessage(content=[mock_item]) + # It should hit line 187 and convert to TextPromptMessageContent + assert isinstance(msg.content[0], TextPromptMessageContent) + assert msg.content[0].data == "test" + + def test_prompt_message_get_text_content_branches(self): + # content is None + msg_none = UserPromptMessage(content=None) + assert msg_none.get_text_content() == "" + + # content is list but no text content + image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg") + msg_image = UserPromptMessage(content=[image]) + assert msg_image.get_text_content() == "" + + # content is list with mixed + text = TextPromptMessageContent(data="hello") + msg_mixed = UserPromptMessage(content=[text, image]) + assert msg_mixed.get_text_content() == "hello" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py new file mode 100644 index 0000000000..3d03361f2a --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py @@ -0,0 +1,220 @@ +from decimal import Decimal + +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ModelUsage, + ParameterRule, + ParameterType, + PriceConfig, + PriceInfo, + PriceType, + ProviderModel, +) + + +class TestModelType: + def test_value_of(self): + assert ModelType.value_of("text-generation") == ModelType.LLM + assert ModelType.value_of(ModelType.LLM) == ModelType.LLM + assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING + assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING + assert ModelType.value_of("reranking") == ModelType.RERANK + assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK + assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT + assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT + assert ModelType.value_of("tts") == ModelType.TTS + assert ModelType.value_of(ModelType.TTS) == ModelType.TTS + assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION + + with pytest.raises(ValueError, match="invalid origin model type invalid"): + ModelType.value_of("invalid") + + def test_to_origin_model_type(self): + assert ModelType.LLM.to_origin_model_type() == "text-generation" + assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings" + assert ModelType.RERANK.to_origin_model_type() == "reranking" + assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text" + assert ModelType.TTS.to_origin_model_type() == "tts" + assert ModelType.MODERATION.to_origin_model_type() == "moderation" + + # The else branch of to_origin_model_type raises ValueError for an unrecognized value. + # Since ModelType is a StrEnum, a real member can never hold an invalid value, so a + # non-member stand-in has to be passed through the unbound method instead (the + # implementation is: if self == self.LLM: ... elif ... else: raise ValueError). + # See the hedged sketch below.
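+        # A hedged sketch, not the canonical trigger: any object whose attribute
+        # lookups return fresh sentinels compares unequal to every member, so the
+        # unbound call should fall through to the final else and raise.
+        class _NotAModelType:
+            def __getattr__(self, _name):
+                return object()
+
+        with pytest.raises(ValueError):
+            ModelType.to_origin_model_type(_NotAModelType())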
+ + +class TestFetchFrom: + def test_values(self): + assert FetchFrom.PREDEFINED_MODEL == "predefined-model" + assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model" + + +class TestModelFeature: + def test_values(self): + assert ModelFeature.TOOL_CALL == "tool-call" + assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call" + assert ModelFeature.AGENT_THOUGHT == "agent-thought" + assert ModelFeature.VISION == "vision" + assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call" + assert ModelFeature.DOCUMENT == "document" + assert ModelFeature.VIDEO == "video" + assert ModelFeature.AUDIO == "audio" + assert ModelFeature.STRUCTURED_OUTPUT == "structured-output" + + +class TestDefaultParameterName: + def test_value_of(self): + assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE + assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P + + with pytest.raises(ValueError, match="invalid parameter name invalid"): + DefaultParameterName.value_of("invalid") + + +class TestParameterType: + def test_values(self): + assert ParameterType.FLOAT == "float" + assert ParameterType.INT == "int" + assert ParameterType.STRING == "string" + assert ParameterType.BOOLEAN == "boolean" + assert ParameterType.TEXT == "text" + + +class TestModelPropertyKey: + def test_values(self): + assert ModelPropertyKey.MODE == "mode" + assert ModelPropertyKey.CONTEXT_SIZE == "context_size" + + +class TestProviderModel: + def test_provider_model(self): + model = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model.model == "gpt-4" + assert model.support_structure_output is False + + model_with_features = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.STRUCTURED_OUTPUT], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model_with_features.support_structure_output is True + + +class TestParameterRule: + def test_parameter_rule(self): + rule = ParameterRule( + name="temperature", + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=0.7, + min=0.0, + max=1.0, + precision=2, + ) + assert rule.name == "temperature" + assert rule.default == 0.7 + + +class TestPriceConfig: + def test_price_config(self): + config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD") + assert config.input == Decimal("0.01") + assert config.output == Decimal("0.02") + + +class TestAIModelEntity: + def test_ai_model_entity_no_json_schema(self): + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or []) + + def test_ai_model_entity_with_json_schema(self): + # Case: json_schema in parameter rules, features is None + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", 
label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_features_empty(self): + # Case: json_schema in parameter rules, features is empty list + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_other_features(self): + # Case: json_schema in parameter rules, features has other things + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.VISION], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + assert ModelFeature.VISION in entity.features + + +class TestModelUsage: + def test_model_usage(self): + usage = ModelUsage() + assert isinstance(usage, ModelUsage) + + +class TestPriceType: + def test_values(self): + assert PriceType.INPUT == "input" + assert PriceType.OUTPUT == "output" + + +class TestPriceInfo: + def test_price_info(self): + info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD") + assert info.total_amount == Decimal("0.05") diff --git a/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py new file mode 100644 index 0000000000..af62b2a84c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py @@ -0,0 +1,63 @@ +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class TestInvokeErrors: + def test_invoke_error_with_description(self): + error = InvokeError("Custom description") + assert error.description == "Custom description" + assert str(error) == "Custom description" + assert isinstance(error, ValueError) + + def test_invoke_error_without_description(self): + error = InvokeError() + assert error.description is None + assert str(error) == "InvokeError" + + def test_invoke_connection_error(self): + # Now preserves class-level description + error = InvokeConnectionError() + assert error.description == "Connection Error" + assert str(error) == "Connection Error" + assert isinstance(error, InvokeError) + + # Test with explicit description + error_with_desc = InvokeConnectionError("Connection Error") + assert error_with_desc.description == "Connection Error" + assert str(error_with_desc) == "Connection Error" + + def test_invoke_server_unavailable_error(self): + error = InvokeServerUnavailableError() + assert error.description == "Server Unavailable Error" + assert str(error) == "Server Unavailable Error" + assert isinstance(error, InvokeError) + + def test_invoke_rate_limit_error(self): + error = InvokeRateLimitError() + assert error.description == "Rate Limit Error" + assert str(error) == "Rate Limit Error" + assert 
isinstance(error, InvokeError) + + def test_invoke_authorization_error(self): + error = InvokeAuthorizationError() + assert error.description == "Incorrect model credentials provided, please check and try again. " + assert str(error) == "Incorrect model credentials provided, please check and try again. " + assert isinstance(error, InvokeError) + + def test_invoke_bad_request_error(self): + error = InvokeBadRequestError() + assert error.description == "Bad Request Error" + assert str(error) == "Bad Request Error" + assert isinstance(error, InvokeError) + + def test_invoke_error_inheritance(self): + # Test that we can override the default description in subclasses + error = InvokeBadRequestError("Overridden Error") + assert error.description == "Overridden Error" + assert str(error) == "Overridden Error" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py new file mode 100644 index 0000000000..382dce876e --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py @@ -0,0 +1,336 @@ +import decimal +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, + PriceType, +) +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel + + +class TestAIModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def ai_model(self, mock_plugin_model_provider): + return AIModel( + tenant_id="tenant_123", + model_type=ModelType.LLM, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_invoke_error_mapping(self, ai_model): + mapping = ai_model._invoke_error_mapping + assert InvokeConnectionError in mapping + assert InvokeServerUnavailableError in mapping + assert InvokeRateLimitError in mapping + assert InvokeAuthorizationError in mapping + assert InvokeBadRequestError in mapping + assert PluginDaemonInnerError in mapping + assert ValueError in mapping + + def test_transform_invoke_error(self, ai_model): + # Case: mapped error (InvokeAuthorizationError) + err = Exception("Original error") + with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeAuthorizationError) + assert "Incorrect model credentials provided" in str(transformed.description) + + # Case: mapped error (InvokeError subclass) + with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeError) + assert "[test_provider]" in transformed.description + + # Case: mapped error (not InvokeError) + class CustomNonInvokeError(Exception): + 
pass + + with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert transformed == err + + # Case: unmapped error + unmapped_err = Exception("Unmapped") + transformed = ai_model._transform_invoke_error(unmapped_err) + assert isinstance(transformed, InvokeError) + assert "Error: Unmapped" in transformed.description + + def test_get_price(self, ai_model): + model_name = "test_model" + credentials = {"key": "value"} + + # Mock get_model_schema + mock_schema = MagicMock(spec=AIModelEntity) + mock_schema.pricing = PriceConfig( + input=decimal.Decimal("0.002"), + output=decimal.Decimal("0.004"), + unit=decimal.Decimal(1000), # 1000 tokens per unit + currency="USD", + ) + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + # Test INPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.002") + + # Test OUTPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.004") + + # Case: unit_price is None (returns zeroed PriceInfo) + mock_schema.pricing = None + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) + assert price_info.total_amount == decimal.Decimal("0.0") + + def test_get_price_no_price_config_error(self, ai_model): + model_name = "test_model" + + # We need it to be truthy at line 107 and 112 but falsy at line 127. + class ChangingPriceConfig: + def __init__(self): + self.input = decimal.Decimal("0.01") + self.unit = decimal.Decimal(1) + self.currency = "USD" + self.called = 0 + + def __bool__(self): + self.called += 1 + return self.called <= 2 + + mock_schema = MagicMock() + mock_schema.pricing = ChangingPriceConfig() + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + with pytest.raises(ValueError) as excinfo: + ai_model.get_price(model_name, {}, PriceType.INPUT, 1000) + assert "Price config not found" in str(excinfo.value) + + def test_get_model_schema_cache_hit(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis: + mock_redis.get.return_value = mock_schema.model_dump_json().encode() + + schema = ai_model.get_model_schema(model_name, credentials) + + assert schema.model == "test_model" + mock_redis.get.assert_called_once() + + def test_get_model_schema_cache_miss(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = None + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = mock_schema + + 
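# Cache-miss path: redis returns None, so the schema must come from the plugin daemon and then be written back to the cache via setex (both asserted below). +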
+            schema = ai_model.get_model_schema(model_name, credentials)
+
+            assert schema == mock_schema
+            mock_manager.get_model_schema.assert_called_once()
+            mock_redis.setex.assert_called_once()
+
+    def test_get_model_schema_redis_error(self, ai_model):
+        model_name = "test_model"
+
+        with (
+            patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
+            patch("core.plugin.impl.model.PluginModelClient") as mock_client,
+        ):
+            mock_redis.get.side_effect = RedisError("Connection refused")
+            mock_manager = mock_client.return_value
+            mock_manager.get_model_schema.return_value = None
+
+            schema = ai_model.get_model_schema(model_name, {})
+
+            assert schema is None
+            mock_manager.get_model_schema.assert_called_once()
+
+    def test_get_model_schema_validation_error(self, ai_model):
+        model_name = "test_model"
+
+        with (
+            patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
+            patch("core.plugin.impl.model.PluginModelClient") as mock_client,
+        ):
+            mock_redis.get.return_value = b"invalid json"
+            mock_manager = mock_client.return_value
+            mock_manager.get_model_schema.return_value = None
+
+            # This should trigger ValidationError at line 166 and go to delete()
+            schema = ai_model.get_model_schema(model_name, {})
+
+            assert schema is None
+            mock_redis.delete.assert_called()
+
+    def test_get_model_schema_redis_delete_error(self, ai_model):
+        model_name = "test_model"
+
+        with (
+            patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
+            patch("core.plugin.impl.model.PluginModelClient") as mock_client,
+        ):
+            mock_redis.get.return_value = b'{"invalid": "schema"}'
+            mock_redis.delete.side_effect = RedisError("Delete failed")
+            mock_manager = mock_client.return_value
+            mock_manager.get_model_schema.return_value = None
+
+            schema = ai_model.get_model_schema(model_name, {})
+
+            assert schema is None
+            mock_redis.delete.assert_called()
+
+    def test_get_model_schema_redis_setex_error(self, ai_model):
+        model_name = "test_model"
+        mock_schema = AIModelEntity(
+            model="test_model",
+            label=I18nObject(en_US="Test Model"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.PREDEFINED_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
+            parameter_rules=[],
+        )
+
+        with (
+            patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis,
+            patch("core.plugin.impl.model.PluginModelClient") as mock_client,
+        ):
+            mock_redis.get.return_value = None
+            mock_redis.setex.side_effect = RuntimeError("Setex failed")
+            mock_manager = mock_client.return_value
+            mock_manager.get_model_schema.return_value = mock_schema
+
+            schema = ai_model.get_model_schema(model_name, {})
+
+            assert schema == mock_schema
+            mock_redis.setex.assert_called()
+
+    def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model):
+        model_name = "test_model"
+
+        mock_schema = AIModelEntity(
+            model="test_model",
+            label=I18nObject(en_US="Test Model"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
+            parameter_rules=[
+                ParameterRule(
+                    name="invalid",
+                    use_template="invalid_template_name",
+                    label=I18nObject(en_US="Invalid"),
+                    type=ParameterType.FLOAT,
+                )
+            ],
+        )
+
+        with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
+            schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
+            assert schema.parameter_rules[0].use_template == "invalid_template_name"
+
+    def test_get_customizable_model_schema_from_credentials(self, ai_model):
+        model_name = "test_model"
+
+        mock_schema = AIModelEntity(
+            model="test_model",
+            label=I18nObject(en_US="Test Model"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
+            parameter_rules=[
+                ParameterRule(
+                    name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT
+                ),
+                ParameterRule(
+                    name="top_p",
+                    use_template="top_p",
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US=""),
+                ),
+                ParameterRule(
+                    name="max_tokens",
+                    use_template="max_tokens",
+                    label=I18nObject(en_US="Max Tokens"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="", zh_Hans=""),
+                ),
+                ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING),
+            ],
+        )
+
+        with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema):
+            schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {})
+
+            assert schema.parameter_rules[0].max == 1.0
+            assert schema.parameter_rules[1].help.en_US != ""
+            assert schema.parameter_rules[2].help.zh_Hans != ""
+            assert schema.parameter_rules[3].use_template is None
+
+    def test_get_customizable_model_schema_from_credentials_none(self, ai_model):
+        with patch.object(AIModel, "get_customizable_model_schema", return_value=None):
+            schema = ai_model.get_customizable_model_schema_from_credentials("model", {})
+            assert schema is None
+
+    def test_get_customizable_model_schema_default(self, ai_model):
+        assert ai_model.get_customizable_model_schema("model", {}) is None
+
+    def test_get_default_parameter_rule_variable_map(self, ai_model):
+        # Valid
+        res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE)
+        assert res["default"] == 0.0
+
+        # Invalid
+        with pytest.raises(Exception) as excinfo:
+            ai_model._get_default_parameter_rule_variable_map("invalid_name")
+        assert "Invalid model parameter rule name" in str(excinfo.value)
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py
new file mode 100644
index 0000000000..a692f8023a
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py
@@ -0,0 +1,476 @@
+import logging
+from collections.abc import Generator, Iterator, Sequence
+from dataclasses import dataclass, field
+from datetime import datetime
+from decimal import Decimal
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module
+
+# Access large_language_model members via llm_module to avoid partial import issues in CI
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.callbacks.base_callback import Callback
+from dify_graph.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMUsage,
+)
+from dify_graph.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from dify_graph.model_runtime.entities.model_entities import ModelType, PriceInfo
+from dify_graph.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks
+
+
+def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage:
+    return LLMUsage(
+        prompt_tokens=prompt_tokens,
+        prompt_unit_price=Decimal("0.001"),
+        prompt_price_unit=Decimal(1),
+        prompt_price=Decimal(prompt_tokens) * Decimal("0.001"),
+        completion_tokens=completion_tokens,
+        completion_unit_price=Decimal("0.002"),
+        completion_price_unit=Decimal(1),
+        completion_price=Decimal(completion_tokens) * Decimal("0.002"),
+        total_tokens=prompt_tokens + completion_tokens,
+        total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"),
+        currency="USD",
+        latency=0.0,
+    )
+
+
+def _tool_call_delta(
+    *,
+    tool_call_id: str,
+    tool_type: str = "function",
+    function_name: str = "",
+    function_arguments: str = "",
+) -> AssistantPromptMessage.ToolCall:
+    return AssistantPromptMessage.ToolCall(
+        id=tool_call_id,
+        type=tool_type,
+        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments),
+    )
+
+
+def _chunk(
+    *,
+    model: str = "test-model",
+    content: str | list[Any] | None = None,
+    tool_calls: list[AssistantPromptMessage.ToolCall] | None = None,
+    usage: LLMUsage | None = None,
+    system_fingerprint: str | None = None,
+) -> LLMResultChunk:
+    return LLMResultChunk(
+        model=model,
+        system_fingerprint=system_fingerprint,
+        delta=LLMResultChunkDelta(
+            index=0,
+            message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []),
+            usage=usage,
+        ),
+    )
+
+
+@dataclass
+class SpyCallback(Callback):
+    raise_error: bool = False
+    before: list[dict[str, Any]] = field(default_factory=list)
+    new_chunk: list[dict[str, Any]] = field(default_factory=list)
+    after: list[dict[str, Any]] = field(default_factory=list)
+    error: list[dict[str, Any]] = field(default_factory=list)
+
+    def on_before_invoke(self, **kwargs: Any) -> None:  # type: ignore[override]
+        self.before.append(kwargs)
+
+    def on_new_chunk(self, **kwargs: Any) -> None:  # type: ignore[override]
+        self.new_chunk.append(kwargs)
+
+    def on_after_invoke(self, **kwargs: Any) -> None:  # type: ignore[override]
+        self.after.append(kwargs)
+
+    def on_invoke_error(self, **kwargs: Any) -> None:  # type: ignore[override]
+        self.error.append(kwargs)
+
+
+class _TestLLM(llm_module.LargeLanguageModel):
+    def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo:  # type: ignore[override]
+        return PriceInfo(
+            unit_price=Decimal("0.01"),
+            unit=Decimal(1),
+            total_amount=Decimal(tokens) * Decimal("0.01"),
+            currency="USD",
+        )
+
+    def _transform_invoke_error(self, error: Exception) -> Exception:  # type: ignore[override]
+        return RuntimeError(f"transformed: {error}")
+
+
+@pytest.fixture
+def llm() -> _TestLLM:
+    plugin_provider = PluginModelProviderEntity.model_construct(
+        id="provider-id",
+        created_at=datetime.now(),
+        updated_at=datetime.now(),
+        provider="provider",
+        tenant_id="tenant",
+        plugin_unique_identifier="plugin-uid",
+        plugin_id="plugin-id",
+        declaration=MagicMock(),
+    )
+    return _TestLLM.model_construct(
+        tenant_id="tenant",
+        model_type=ModelType.LLM,
+        plugin_id="plugin-id",
+        provider_name="provider",
+        plugin_model_provider=plugin_provider,
+        started_at=1.0,
+    )
+
+
+def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123"))
+    assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123"
+
+
+def test_run_callbacks_no_callbacks_noop() -> None:
+    invoked: list[int] = []
+    llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1))
+    llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1))
+    assert invoked == []
+
+
+def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None:
+    class Boom:
+        raise_error = False
+
+    caplog.set_level(logging.WARNING)
+    llm_module._run_callbacks(
+        [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
+    )
+    assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records)
+
+
+def test_run_callbacks_reraises_when_raise_error_true() -> None:
+    class Boom:
+        raise_error = True
+
+    with pytest.raises(ValueError, match="boom"):
+        llm_module._run_callbacks(
+            [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom"))
+        )
+
+
+def test_get_or_create_tool_call_empty_id_returns_last() -> None:
+    calls = [
+        _tool_call_delta(tool_call_id="id1", function_name="a"),
+        _tool_call_delta(tool_call_id="id2", function_name="b"),
+    ]
+    assert llm_module._get_or_create_tool_call(calls, "") is calls[-1]
+
+
+def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None:
+    with pytest.raises(ValueError, match="tool_call_id is empty"):
+        llm_module._get_or_create_tool_call([], "")
+
+
+def test_get_or_create_tool_call_creates_if_missing() -> None:
+    calls: list[AssistantPromptMessage.ToolCall] = []
+    tool_call = llm_module._get_or_create_tool_call(calls, "new-id")
+    assert tool_call.id == "new-id"
+    assert tool_call.function.name == ""
+    assert tool_call.function.arguments == ""
+    assert calls == [tool_call]
+
+
+def test_get_or_create_tool_call_returns_existing_when_found() -> None:
+    existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}")
+    calls = [existing]
+    assert llm_module._get_or_create_tool_call(calls, "same-id") is existing
+
+
+def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None:
+    tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{")
+    delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}")
+    llm_module._merge_tool_call_delta(tool_call, delta)
+    assert tool_call.id == "id2"
+    assert tool_call.type == "function"
+    assert tool_call.function.name == "y"
+    assert tool_call.function.arguments == "{}"
+
+
+def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed"))
+    delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{")
+    existing: list[AssistantPromptMessage.ToolCall] = []
+    llm_module._increase_tool_call([delta], existing)
+    assert len(existing) == 1
+    assert existing[0].id == "chatcmpl-tool-fixed"
+    assert existing[0].function.name == "fn"
+    assert existing[0].function.arguments == "{"
+
+
+def test_increase_tool_call_merges_incremental_arguments() -> None:
+    existing: list[AssistantPromptMessage.ToolCall] = []
+    llm_module._increase_tool_call(
+        [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing
+    )
+    llm_module._increase_tool_call(
+        [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing
+    )
+    assert len(existing) == 1
+    assert existing[0].function.name == "fn"
+    assert existing[0].function.arguments == "{}"
+
+
+@pytest.mark.parametrize(
+    ("content", "expected_type"),
+    [
+        ("hello", str),
+        ([TextPromptMessageContent(data="hello")], list),
+    ],
+)
+def test_build_llm_result_from_chunks_accumulates_and_raises_error(
+    content: str | list[TextPromptMessageContent],
+    expected_type: type,
+    monkeypatch: pytest.MonkeyPatch,
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain"))
+    caplog.set_level(logging.DEBUG)
+
+    tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}")
+    first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1")
+
+    def iter_with_error() -> Iterator[LLMResultChunk]:
+        yield first
+        raise RuntimeError("drain boom")
+
+    with pytest.raises(RuntimeError, match="drain boom"):
+        _build_llm_result_from_chunks(
+            model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error()
+        )
+
+    assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records)
+
+
+def test_build_llm_result_from_chunks_empty_iterator() -> None:
+    def empty() -> Iterator[LLMResultChunk]:
+        if False:  # pragma: no cover
+            yield _chunk()
+        return
+
+    result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty())
+    assert result.message.content == []
+    assert result.usage.total_tokens == 0
+    assert result.system_fingerprint is None
+
+
+def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None:
+    chunks = iter([_chunk(content="first"), _chunk(content="second")])
+    result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks)
+    assert result.message.content == "firstsecond"
+
+
+def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None:
+    invoked: dict[str, Any] = {}
+
+    class FakePluginModelClient:
+        def invoke_llm(self, **kwargs: Any) -> str:
+            invoked.update(kwargs)
+            return "ok"
+
+    import core.plugin.impl.model as plugin_model_module
+
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
+
+    prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),)
+    result = llm_module._invoke_llm_via_plugin(
+        tenant_id="t",
+        user_id="u",
+        plugin_id="p",
+        provider="prov",
+        model="m",
+        credentials={"k": "v"},
+        model_parameters={"temp": 1},
+        prompt_messages=prompt_messages,
+        tools=None,
+        stop=("a", "b"),
+        stream=True,
+    )
+
+    assert result == "ok"
+    assert invoked["prompt_messages"] == list(prompt_messages)
+    assert invoked["stop"] == ["a", "b"]
+
+
+def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None:
+    llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
+    assert (
+        llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result
+    )
+
+
+def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None:
+    chunks = iter([_chunk(content="hello", usage=_usage(1, 1))])
+    result = llm_module._normalize_non_stream_plugin_result(
+        model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks
+    )
+    assert isinstance(result, LLMResult)
+    assert result.message.content == "hello"
+
+
+def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
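+    # The plugin returns a ready-made LLMResult; invoke() should return it with the caller's
+    # prompt_messages attached and fire the before/after callbacks exactly once.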
+    plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage())
+    monkeypatch.setattr(
+        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
+        lambda **_: plugin_result,
+    )
+    cb = SpyCallback()
+    prompt_messages = [UserPromptMessage(content="hi")]
+    result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb])
+    assert isinstance(result, LLMResult)
+    assert result.prompt_messages == prompt_messages
+    assert len(cb.before) == 1
+    assert len(cb.after) == 1
+    assert cb.after[0]["result"].prompt_messages == prompt_messages
+
+
+def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    plugin_chunks = iter(
+        [
+            _chunk(model="m1", content="a"),
+            _chunk(
+                model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp"
+            ),
+            _chunk(model="m3", content=None),
+        ]
+    )
+    monkeypatch.setattr(
+        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
+        lambda **_: plugin_chunks,
+    )
+
+    cb = SpyCallback()
+    prompt_messages = [UserPromptMessage(content="hi")]
+    gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb])
+
+    assert isinstance(gen, Generator)
+    chunks = list(gen)
+    assert len(chunks) == 3
+    assert all(chunk.prompt_messages == prompt_messages for chunk in chunks)
+    assert len(cb.before) == 1
+    assert len(cb.new_chunk) == 3
+    assert len(cb.after) == 1
+    final_result: LLMResult = cb.after[0]["result"]
+    assert final_result.model == "m3"
+    assert final_result.system_fingerprint == "fp"
+    assert isinstance(final_result.message.content, list)
+    assert [c.data for c in final_result.message.content] == ["a", "b"]
+    assert final_result.usage.total_tokens == 5
+
+
+def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    def boom(**_: Any) -> Any:
+        raise ValueError("plugin down")
+
+    monkeypatch.setattr(
+        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom
+    )
+    cb = SpyCallback()
+    with pytest.raises(RuntimeError, match="transformed: plugin down"):
+        llm.invoke(
+            model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb]
+        )
+    assert len(cb.error) == 1
+    assert isinstance(cb.error[0]["ex"], ValueError)
+
+
+def test_invoke_raises_not_implemented_for_unsupported_result_type(
+    llm: _TestLLM, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result")
+    monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result")
+    with pytest.raises(NotImplementedError, match="unsupported invoke result type"):
+        llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
+
+
+def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    captured_callbacks: list[list[Callback]] = []
+
+    class FakeLoggingCallback(SpyCallback):
+        pass
+
+    monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback)
+    monkeypatch.setattr(llm_module.dify_config, "DEBUG", True)
+    monkeypatch.setattr(
+        "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin",
+        lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()),
+    )
+
+    original_trigger = llm._trigger_before_invoke_callbacks
+
+    def spy_trigger(*args: Any, **kwargs: Any) -> None:
+        captured_callbacks.append(list(kwargs["callbacks"]))
+        original_trigger(*args, **kwargs)
+
+    monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger)
+    llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False)
+    assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0])
+
+
+def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False)
+    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0
+
+
+def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True)
+
+    class FakePluginModelClient:
+        def get_llm_num_tokens(self, **kwargs: Any) -> int:
+            assert kwargs["tenant_id"] == "tenant"
+            assert kwargs["plugin_id"] == "plugin-id"
+            assert kwargs["provider"] == "provider"
+            assert kwargs["model_type"] == "llm"
+            return 42
+
+    import core.plugin.impl.model as plugin_model_module
+
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
+    assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42
+
+
+def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5)
+    llm.started_at = 1.0
+    usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5)
+    assert usage.total_tokens == 15
+    assert usage.total_price == Decimal("0.15")
+    assert usage.latency == 3.5
+
+
+def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None:
+    def broken() -> Iterator[LLMResultChunk]:
+        yield _chunk(content="ok")
+        raise ValueError("chunk stream broken")
+
+    gen = llm._invoke_result_generator(
+        model="m",
+        result=broken(),
+        credentials={},
+        prompt_messages=[UserPromptMessage(content="u")],
+        model_parameters={},
+        callbacks=[SpyCallback()],
+    )
+
+    with pytest.raises(RuntimeError, match="transformed: chunk stream broken"):
+        list(gen)
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py
new file mode 100644
index 0000000000..6ccc44ceb8
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py
@@ -0,0 +1,90 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.entities.model_entities import ModelType
+from dify_graph.model_runtime.errors.invoke import InvokeError
+from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
+
+
+class TestModerationModel:
+    @pytest.fixture
+    def mock_plugin_model_provider(self):
+        return MagicMock(spec=PluginModelProviderEntity)
+
+    @pytest.fixture
+    def moderation_model(self, mock_plugin_model_provider):
+        return ModerationModel(
+            tenant_id="tenant_123",
+            model_type=ModelType.MODERATION,
+            plugin_id="plugin_123",
+            provider_name="test_provider",
+            plugin_model_provider=mock_plugin_model_provider,
+        )
+
+    def test_model_type(self, moderation_model):
+        assert moderation_model.model_type == ModelType.MODERATION
+
+    def test_invoke_success(self, moderation_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        text = "test text"
+        user = "user_123"
+
+        with (
+            patch("core.plugin.impl.model.PluginModelClient") as mock_client_class,
+            patch("time.perf_counter", return_value=1.0),
+        ):
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_moderation.return_value = True
+
+            result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user)
+
+            assert result is True
+            assert moderation_model.started_at == 1.0
+            mock_client.invoke_moderation.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="user_123",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                text=text,
+            )
+
+    def test_invoke_success_no_user(self, moderation_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        text = "test text"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_moderation.return_value = False
+
+            result = moderation_model.invoke(model=model_name, credentials=credentials, text=text)
+
+            assert result is False
+            mock_client.invoke_moderation.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                text=text,
+            )
+
+    def test_invoke_exception(self, moderation_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        text = "test text"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_moderation.side_effect = Exception("Test error")
+
+            with pytest.raises(InvokeError) as excinfo:
+                moderation_model.invoke(model=model_name, credentials=credentials, text=text)
+
+            assert "[test_provider] Error: Test error" in str(excinfo.value.description)
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py
new file mode 100644
index 0000000000..67828894b3
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py
@@ -0,0 +1,181 @@
+from datetime import datetime
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.entities.model_entities import ModelType
+from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+@pytest.fixture
+def rerank_model() -> RerankModel:
+    plugin_provider = PluginModelProviderEntity.model_construct(
+        id="provider-id",
+        created_at=datetime.now(),
+        updated_at=datetime.now(),
+        provider="provider",
+        tenant_id="tenant",
+        plugin_unique_identifier="plugin-uid",
+        plugin_id="plugin-id",
+        declaration=MagicMock(),
+    )
+    return RerankModel.model_construct(
+        tenant_id="tenant",
+        model_type=ModelType.RERANK,
+        plugin_id="plugin-id",
+        provider_name="provider",
+        plugin_model_provider=plugin_provider,
+    )
+
+
+def test_model_type_is_rerank_by_default() -> None:
+    plugin_provider = PluginModelProviderEntity.model_construct(
+        id="provider-id",
+        created_at=datetime.now(),
+        updated_at=datetime.now(),
+        provider="provider",
+        tenant_id="tenant",
+        plugin_unique_identifier="plugin-uid",
+        plugin_id="plugin-id",
+        declaration=MagicMock(),
+    )
+    model = RerankModel(
+        tenant_id="tenant",
+        plugin_id="plugin-id",
+        provider_name="provider",
+        plugin_model_provider=plugin_provider,
+    )
+    assert model.model_type == ModelType.RERANK
+
+
+def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
+    expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
+
+    class FakePluginModelClient:
+        def __init__(self) -> None:
+            self.invoke_rerank_called_with: dict[str, Any] | None = None
+
+        def invoke_rerank(self, **kwargs: Any) -> RerankResult:
+            self.invoke_rerank_called_with = kwargs
+            return expected
+
+    import core.plugin.impl.model as plugin_model_module
+
+    fake_client = FakePluginModelClient()
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
+
+    result = rerank_model.invoke(
+        model="rerank",
+        credentials={"k": "v"},
+        query="q",
+        docs=["d1", "d2"],
+        score_threshold=0.2,
+        top_n=10,
+        user="user-1",
+    )
+
+    assert result == expected
+    assert fake_client.invoke_rerank_called_with == {
+        "tenant_id": "tenant",
+        "user_id": "user-1",
+        "plugin_id": "plugin-id",
+        "provider": "provider",
+        "model": "rerank",
+        "credentials": {"k": "v"},
+        "query": "q",
+        "docs": ["d1", "d2"],
+        "score_threshold": 0.2,
+        "top_n": 10,
+    }
+
+
+def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
+    class FakePluginModelClient:
+        def __init__(self) -> None:
+            self.kwargs: dict[str, Any] | None = None
+
+        def invoke_rerank(self, **kwargs: Any) -> RerankResult:
+            self.kwargs = kwargs
+            return RerankResult(model="m", docs=[])
+
+    import core.plugin.impl.model as plugin_model_module
+
+    fake_client = FakePluginModelClient()
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
+
+    rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
+    assert fake_client.kwargs is not None
+    assert fake_client.kwargs["user_id"] == "unknown"
+
+
+def test_invoke_transforms_and_raises_on_plugin_error(
+    rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    class FakePluginModelClient:
+        def invoke_rerank(self, **_: Any) -> RerankResult:
+            raise ValueError("plugin down")
+
+    import core.plugin.impl.model as plugin_model_module
+
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
+    monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
+
+    with pytest.raises(RuntimeError, match="transformed: plugin down"):
+        rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
+
+
+def test_invoke_multimodal_calls_plugin_and_passes_args(
+    rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
+
+    class FakePluginModelClient:
+        def __init__(self) -> None:
+            self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None
+
+        def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult:
+            self.invoke_multimodal_rerank_called_with = kwargs
+            return expected
+
+    import core.plugin.impl.model as plugin_model_module
+
+    fake_client = FakePluginModelClient()
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
+
+    query = {"type": "text", "text": "q"}
+    docs = [{"type": "text", "text": "d1"}]
+    result = rerank_model.invoke_multimodal_rerank(
+        model="mm",
+        credentials={"k": "v"},
+        query=query,
+        docs=docs,
+        score_threshold=None,
+        top_n=None,
+        user=None,
+    )
+
+    assert result == expected
+    assert fake_client.invoke_multimodal_rerank_called_with is not None
+    assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant"
+    assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown"
+    assert fake_client.invoke_multimodal_rerank_called_with["query"] == query
+    assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs
+
+
+def test_invoke_multimodal_transforms_and_raises_on_plugin_error(
+    rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    class FakePluginModelClient:
+        def invoke_multimodal_rerank(self, **_: Any) -> RerankResult:
+            raise ValueError("plugin down")
+
+    import core.plugin.impl.model as plugin_model_module
+
+    monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
+    monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
+
+    with pytest.raises(RuntimeError, match="transformed: plugin down"):
+        rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}])
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py
new file mode 100644
index 0000000000..f891718dc6
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py
@@ -0,0 +1,87 @@
+from io import BytesIO
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.entities.model_entities import ModelType
+from dify_graph.model_runtime.errors.invoke import InvokeError
+from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+
+
+class TestSpeech2TextModel:
+    @pytest.fixture
+    def mock_plugin_model_provider(self):
+        return MagicMock(spec=PluginModelProviderEntity)
+
+    @pytest.fixture
+    def speech2text_model(self, mock_plugin_model_provider):
+        return Speech2TextModel(
+            tenant_id="tenant_123",
+            model_type=ModelType.SPEECH2TEXT,
+            plugin_id="plugin_123",
+            provider_name="test_provider",
+            plugin_model_provider=mock_plugin_model_provider,
+        )
+
+    def test_model_type(self, speech2text_model):
+        assert speech2text_model.model_type == ModelType.SPEECH2TEXT
+
+    def test_invoke_success(self, speech2text_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        file = BytesIO(b"audio data")
+        user = "user_123"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_speech_to_text.return_value = "transcribed text"
+
+            result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user)
+
+            assert result == "transcribed text"
+            mock_client.invoke_speech_to_text.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="user_123",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                file=file,
+            )
+
+    def test_invoke_success_no_user(self, speech2text_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        file = BytesIO(b"audio data")
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_speech_to_text.return_value = "transcribed text"
+
+            result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
+
+            assert result == "transcribed text"
+            mock_client.invoke_speech_to_text.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                file=file,
+            )
+
+    def test_invoke_exception(self, speech2text_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        file = BytesIO(b"audio data")
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_speech_to_text.side_effect = Exception("Test error")
+
+            with pytest.raises(InvokeError) as excinfo:
+                speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
+
+            assert "[test_provider] Error: Test error" in str(excinfo.value.description)
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py
new file mode 100644
index 0000000000..c8f0a2ad49
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py
@@ -0,0 +1,185 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.entities.embedding_type import EmbeddingInputType
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
+from dify_graph.model_runtime.errors.invoke import InvokeError
+from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class TestTextEmbeddingModel:
+    @pytest.fixture
+    def mock_plugin_model_provider(self):
+        return MagicMock(spec=PluginModelProviderEntity)
+
+    @pytest.fixture
+    def text_embedding_model(self, mock_plugin_model_provider):
+        return TextEmbeddingModel(
+            tenant_id="tenant_123",
+            model_type=ModelType.TEXT_EMBEDDING,
+            plugin_id="plugin_123",
+            provider_name="test_provider",
+            plugin_model_provider=mock_plugin_model_provider,
+        )
+
+    def test_model_type(self, text_embedding_model):
+        assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
+
+    def test_invoke_with_texts(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        texts = ["hello", "world"]
+        user = "user_123"
+        expected_result = MagicMock(spec=EmbeddingResult)
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_text_embedding.return_value = expected_result
+
+            result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user)
+
+            assert result == expected_result
+            mock_client.invoke_text_embedding.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="user_123",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                texts=texts,
+                input_type=EmbeddingInputType.DOCUMENT,
+            )
+
+    def test_invoke_with_multimodel_documents(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        multimodel_documents = [{"type": "text", "text": "hello"}]
+        expected_result = MagicMock(spec=EmbeddingResult)
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_multimodal_embedding.return_value = expected_result
+
+            result = text_embedding_model.invoke(
+                model=model_name, credentials=credentials, multimodel_documents=multimodel_documents
+            )
+
+            assert result == expected_result
+            mock_client.invoke_multimodal_embedding.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                documents=multimodel_documents,
+                input_type=EmbeddingInputType.DOCUMENT,
+            )
+
+    def test_invoke_no_input(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+
+        with pytest.raises(ValueError) as excinfo:
+            text_embedding_model.invoke(model=model_name, credentials=credentials)
+
+        assert "No texts or files provided" in str(excinfo.value)
+
+    def test_invoke_precedence(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        texts = ["hello"]
+        multimodel_documents = [{"type": "text", "text": "world"}]
+        expected_result = MagicMock(spec=EmbeddingResult)
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_text_embedding.return_value = expected_result
+
+            result = text_embedding_model.invoke(
+                model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents
+            )
+
+            assert result == expected_result
+            mock_client.invoke_text_embedding.assert_called_once()
+            mock_client.invoke_multimodal_embedding.assert_not_called()
+
+    def test_invoke_exception(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        texts = ["hello"]
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_text_embedding.side_effect = Exception("Test error")
+
+            with pytest.raises(InvokeError) as excinfo:
+                text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts)
+
+            assert "[test_provider] Error: Test error" in str(excinfo.value.description)
+
+    def test_get_num_tokens(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        texts = ["hello", "world"]
+        expected_tokens = [1, 1]
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.get_text_embedding_num_tokens.return_value = expected_tokens
+
+            result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts)
+
+            assert result == expected_tokens
+            mock_client.get_text_embedding_num_tokens.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                texts=texts,
+            )
+
+    def test_get_context_size(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+
+        # Test case 1: Context size in schema
+        mock_schema = MagicMock()
+        mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
+
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
+            assert text_embedding_model._get_context_size(model_name, credentials) == 2048
+
+        # Test case 2: No schema
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
+            assert text_embedding_model._get_context_size(model_name, credentials) == 1000
+
+        # Test case 3: Context size NOT in schema properties
+        mock_schema.model_properties = {}
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
+            assert text_embedding_model._get_context_size(model_name, credentials) == 1000
+
+    def test_get_max_chunks(self, text_embedding_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+
+        # Test case 1: Max chunks in schema
+        mock_schema = MagicMock()
+        mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
+
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
+            assert text_embedding_model._get_max_chunks(model_name, credentials) == 10
+
+        # Test case 2: No schema
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
+            assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
+
+        # Test case 3: Max chunks NOT in schema properties
+        mock_schema.model_properties = {}
+        with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
+            assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py
new file mode 100644
index 0000000000..b1aca9baa3
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py
@@ -0,0 +1,131 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from dify_graph.model_runtime.entities.model_entities import ModelType
+from dify_graph.model_runtime.errors.invoke import InvokeError
+from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
+
+
+class TestTTSModel:
+    @pytest.fixture
+    def mock_plugin_model_provider(self):
+        return MagicMock(spec=PluginModelProviderEntity)
+
+    @pytest.fixture
+    def tts_model(self, mock_plugin_model_provider):
+        return TTSModel(
+            tenant_id="tenant_123",
+            model_type=ModelType.TTS,
+            plugin_id="plugin_123",
+            provider_name="test_provider",
+            plugin_model_provider=mock_plugin_model_provider,
+        )
+
+    def test_model_type(self, tts_model):
+        assert tts_model.model_type == ModelType.TTS
+
+    def test_invoke_success(self, tts_model):
+        model_name = "test_model"
+        tenant_id = "ignored_tenant_id"
+        credentials = {"api_key": "abc"}
+        content_text = "Hello world"
+        voice = "alloy"
+        user = "user_123"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_tts.return_value = [b"audio_chunk"]
+
+            result = tts_model.invoke(
+                model=model_name,
+                tenant_id=tenant_id,
+                credentials=credentials,
+                content_text=content_text,
+                voice=voice,
+                user=user,
+            )
+
+            assert list(result) == [b"audio_chunk"]
+            mock_client.invoke_tts.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="user_123",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                content_text=content_text,
+                voice=voice,
+            )
+
+    def test_invoke_success_no_user(self, tts_model):
+        model_name = "test_model"
+        tenant_id = "ignored_tenant_id"
+        credentials = {"api_key": "abc"}
+        content_text = "Hello world"
+        voice = "alloy"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_tts.return_value = [b"audio_chunk"]
+
+            result = tts_model.invoke(
+                model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice
+            )
+
+            assert list(result) == [b"audio_chunk"]
+            mock_client.invoke_tts.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                content_text=content_text,
+                voice=voice,
+            )
+
+    def test_invoke_exception(self, tts_model):
+        model_name = "test_model"
+        tenant_id = "ignored_tenant_id"
+        credentials = {"api_key": "abc"}
+        content_text = "Hello world"
+        voice = "alloy"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.invoke_tts.side_effect = Exception("Test error")
+
+            with pytest.raises(InvokeError) as excinfo:
+                tts_model.invoke(
+                    model=model_name,
+                    tenant_id=tenant_id,
+                    credentials=credentials,
+                    content_text=content_text,
+                    voice=voice,
+                )
+
+            assert "[test_provider] Error: Test error" in str(excinfo.value.description)
+
+    def test_get_tts_model_voices(self, tts_model):
+        model_name = "test_model"
+        credentials = {"api_key": "abc"}
+        language = "en-US"
+
+        with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}]
+
+            result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language)
+
+            assert result == [{"name": "Voice1"}]
+            mock_client.get_tts_model_voices.assert_called_once_with(
+                tenant_id="tenant_123",
+                user_id="unknown",
+                plugin_id="plugin_123",
+                provider="test_provider",
+                model=model_name,
+                credentials=credentials,
+                language=language,
+            )
diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py
new file mode 100644
index 0000000000..dde6ea02b5
--- /dev/null
+++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py
@@ -0,0 +1,96 @@
+from unittest.mock import MagicMock, patch
+
+import dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module
+from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
+
+
+class TestGPT2Tokenizer:
+    def setup_method(self):
+        # Reset the global tokenizer before each test to ensure we test initialization
+        gpt2_tokenizer_module._tokenizer = None
+
+    def test_get_encoder_tiktoken(self):
+        """
+        Test that get_encoder successfully uses tiktoken when available.
+ """ + mock_encoding = MagicMock() + # Mock tiktoken to be sure it's used + with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_get_encoding.assert_called_once_with("gpt2") + + # Verify singleton behavior within the same test + encoder2 = GPT2Tokenizer.get_encoder() + assert encoder2 is encoder + assert mock_get_encoding.call_count == 1 + + def test_get_encoder_tiktoken_fallback(self): + """ + Test that get_encoder falls back to transformers when tiktoken fails. + """ + # patch tiktoken.get_encoding to raise an exception + with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): + # patch transformers.GPT2Tokenizer + with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: + mock_transformer_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_transformer_tokenizer + + with patch( + "dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" + ) as mock_logger: + encoder = GPT2Tokenizer.get_encoder() + + assert encoder == mock_transformer_tokenizer + mock_from_pretrained.assert_called_once() + mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") + + def test_get_num_tokens(self): + """ + Test get_num_tokens returns the correct count. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2, 3, 4, 5] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer.get_num_tokens("test text") + assert tokens_count == 5 + mock_encoder.encode.assert_called_once_with("test text") + + def test_get_num_tokens_by_gpt2_direct(self): + """ + Test _get_num_tokens_by_gpt2 directly. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") + assert tokens_count == 2 + mock_encoder.encode.assert_called_once_with("hello") + + def test_get_encoder_already_initialized(self): + """ + Test that if _tokenizer is already set, it returns it immediately. + """ + mock_existing_tokenizer = MagicMock() + gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer + + # Tiktoken should not be called if already initialized + with patch("tiktoken.get_encoding") as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_existing_tokenizer + mock_get_encoding.assert_not_called() + + def test_get_encoder_thread_safety(self): + """ + Simple test to ensure the lock is used. 
+ """ + mock_encoding = MagicMock() + with patch("tiktoken.get_encoding", return_value=mock_encoding): + # We patch the lock in the module + with patch("dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py new file mode 100644 index 0000000000..1ad0210375 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py @@ -0,0 +1,522 @@ +import logging +from datetime import datetime +from threading import Lock +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +import contexts +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, +) +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _provider_entity( + *, + provider: str, + supported_model_types: list[ModelType] | None = None, + models: list[AIModelEntity] | None = None, + icon_small: I18nObject | None = None, + icon_small_dark: I18nObject | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider), + supported_model_types=supported_model_types or [ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + icon_small=icon_small, + icon_small_dark=icon_small_dark, + ) + + +def _plugin_provider( + *, plugin_id: str, declaration: ProviderEntity, provider: str = "provider" +) -> PluginModelProviderEntity: + return PluginModelProviderEntity.model_construct( + id=f"{plugin_id}-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider=provider, + tenant_id="tenant", + plugin_unique_identifier=f"{plugin_id}-uid", + plugin_id=plugin_id, + declaration=declaration, + ) + + +@pytest.fixture(autouse=True) +def _reset_plugin_model_provider_context() -> None: + contexts.plugin_model_providers_lock.set(Lock()) + contexts.plugin_model_providers.set(None) + + +@pytest.fixture +def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + manager = MagicMock() + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager) + return manager + + +@pytest.fixture +def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory: + return ModelProviderFactory(tenant_id="tenant") + + +def test_get_plugin_model_providers_initializes_context_on_lookup_error( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + original_get = contexts.plugin_model_providers.get + calls = {"n": 0} + + def flaky_get() -> Any: + calls["n"] += 1 + if calls["n"] 
== 1: + raise LookupError + return original_get() + + monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get) + + providers = factory.get_plugin_model_providers() + assert len(providers) == 1 + assert providers[0].declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_providers_caches_and_does_not_refetch( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + first = factory.get_plugin_model_providers() + second = factory.get_plugin_model_providers() + + assert first is second + fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant") + + +def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + d1 = _provider_entity(provider="openai") + d2 = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=d1), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=d2), + ] + + providers = factory.get_providers() + assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"] + + +def test_get_plugin_model_provider_converts_short_provider_id( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + provider = factory.get_plugin_model_provider("openai") + assert provider.declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_provider_raises_on_invalid_provider( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + with pytest.raises(ValueError, match="Invalid provider"): + factory.get_plugin_model_provider("langgenius/unknown/unknown") + + +def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + schema = factory.get_provider_schema("openai") + assert schema.provider == "langgenius/openai/openai" + + +def test_provider_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.provider_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"x": "y"}) + + +def test_provider_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + 
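+    # Attach a provider credential schema so validation reaches the filter-and-validate step.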
schema.provider_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator", + lambda _: fake_validator, + ) + + filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True}) + assert filtered == {"filtered": True} + fake_plugin_manager.validate_provider_credentials.assert_called_once() + kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["credentials"] == {"filtered": True} + + +def test_model_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"} + ) + + +def test_model_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator", + lambda *_: fake_validator, + ) + + filtered = factory.model_credentials_validate( + provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True} + ) + assert filtered == {"filtered": True} + kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["model_type"] == "text-embedding" + assert kwargs["model"] == "m" + assert kwargs["credentials"] == {"filtered": True} + + +def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None: + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = model_schema.model_dump_json().encode() + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"}) + == model_schema + ) + + +def 
test_get_model_schema_cache_invalid_json_deletes_key( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert mock_redis.delete.called + assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_delete_redis_error_is_logged( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + mock_redis.delete.side_effect = RedisError("nope") + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_redis_get_error_falls_back_to_plugin( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.plugin_model_manager.get_model_schema.return_value = None + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.side_effect = RedisError("down") + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + factory.plugin_model_manager.get_model_schema.return_value = model_schema + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_redis.setex.side_effect = RedisError("nope") + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + == model_schema + ) + assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records) + + +@pytest.mark.parametrize( + ("model_type", "expected_class"), + [ + (ModelType.LLM, "LargeLanguageModel"), + (ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"), + (ModelType.RERANK, "RerankModel"), + (ModelType.SPEECH2TEXT, "Speech2TextModel"), + 
(ModelType.MODERATION, "ModerationModel"), + (ModelType.TTS, "TTSModel"), + ], +) +def test_get_model_type_instance_dispatches_by_type( + factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + sentinel = object() + monkeypatch.setattr( + f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}", + MagicMock(model_validate=lambda _: sentinel), + ) + + assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel + + +def test_get_model_type_instance_raises_on_unsupported( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + class UnknownModelType: + pass + + with pytest.raises(ValueError, match="Unsupported model type"): + factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type] + + +def test_get_models_filters_by_provider_and_model_type( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + llm = AIModelEntity( + model="m1", + label=I18nObject(en_US="m1"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + embed = AIModelEntity( + model="e1", + label=I18nObject(en_US="e1"), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + openai = _provider_entity( + provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed] + ) + anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm]) + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + # ModelType filter picks only matching models + providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert [m.model for m in providers[0].models] == ["e1"] + + # Provider filter excludes others + providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/anthropic/anthropic" + + +def test_get_models_provider_filter_skips_non_matching( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + openai = _provider_entity(provider="openai") + anthropic = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM) + assert providers == [] + + +def test_get_provider_icon_fetches_asset_and_returns_mime_type( + factory: 
ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + assert tenant_id == "tenant" + return f"bytes:{id}".encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, mime = factory.get_provider_icon("openai", "icon_small", "en_US") + assert data == b"bytes:icon.png" + assert mime == "image/png" + + data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans") + assert data == b"bytes:dark-zh.svg" + assert mime == "image/svg+xml" + + +def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + return id.encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans") + assert data == b"icon-zh.png" + + data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US") + assert data == b"dark-en.svg" + + +def test_get_provider_icon_raises_for_missing_icons( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity(provider="langgenius/openai/openai") + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + with pytest.raises(ValueError, match="does not have small icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + with pytest.raises(ValueError, match="does not have small dark icon"): + factory.get_provider_icon("openai", "icon_small_dark", "en_US") + + +def test_get_provider_icon_raises_for_unsupported_icon_type( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="Unsupported icon type"): + factory.get_provider_icon("openai", "nope", "en_US") + + +def test_get_provider_icon_raises_when_file_name_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="does not have icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + +def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case( + factory: ModelProviderFactory, +) -> None: + plugin_id, provider_name = 
factory.get_plugin_id_and_provider_name_from_provider("google") + assert plugin_id == "langgenius/gemini" + assert provider_name == "google" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py new file mode 100644 index 0000000000..6d52457c8c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py @@ -0,0 +1,201 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormOption, + FormShowOnObject, + FormType, +) +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator + + +class TestCommonValidator: + def test_validate_credential_form_schema_required_missing(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + with pytest.raises(ValueError, match="Variable api_key is required"): + validator._validate_credential_form_schema(schema, {}) + + def test_validate_credential_form_schema_not_required_missing_with_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + required=False, + default="default_value", + ) + assert validator._validate_credential_form_schema(schema, {}) == "default_value" + + def test_validate_credential_form_schema_not_required_missing_no_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False + ) + assert validator._validate_credential_form_schema(schema, {}) is None + + def test_validate_credential_form_schema_max_length_exceeded(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 + ) + with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): + validator._validate_credential_form_schema(schema, {"api_key": "123456"}) + + def test_validate_credential_form_schema_not_string(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) + with pytest.raises(ValueError, match="Variable api_key should be string"): + validator._validate_credential_form_schema(schema, {"api_key": 123}) + + def test_validate_credential_form_schema_select_invalid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + with pytest.raises(ValueError, match="Variable mode is not in options"): + validator._validate_credential_form_schema(schema, {"mode": "medium"}) + + def test_validate_credential_form_schema_select_valid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + assert 
validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" + + def test_validate_credential_form_schema_switch_invalid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) + + def test_validate_credential_form_schema_switch_valid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True + assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False + + def test_validate_and_filter_credential_form_schemas_with_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="auth_type", + label=I18nObject(en_US="Auth Type"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="API Key"), value="api_key"), + FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), + ], + ), + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ), + CredentialFormSchema( + variable="client_id", + label=I18nObject(en_US="Client ID"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="oauth")], + ), + ] + + # Case 1: auth_type = api_key + credentials = {"auth_type": "api_key", "api_key": "my_secret"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "auth_type" in result + assert "api_key" in result + assert "client_id" not in result + assert result["api_key"] == "my_secret" + + # Case 2: auth_type = oauth + credentials = {"auth_type": "oauth", "client_id": "my_client"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + # "oauth" is a valid select option and a truthy value, so auth_type passes + # validation and is kept in the filtered result.
+ assert "auth_type" in result + assert "api_key" not in result + assert "client_id" in result + assert result["client_id"] == "my_client" + + def test_validate_and_filter_show_on_missing_variable(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is missing in credentials, so api_key should be filtered out + result = validator._validate_and_filter_credential_form_schemas(schemas, {}) + assert result == {} + + def test_validate_and_filter_show_on_mismatch_value(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is oauth, which doesn't match show_on + result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) + assert result == {} + + def test_validate_and_filter_multiple_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="target", + label=I18nObject(en_US="Target"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], + ) + ] + # Both match + assert "target" in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "b", "target": "val"} + ) + # One mismatch + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "c", "target": "val"} + ) + # One missing + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "target": "val"} + ) + + def test_validate_and_filter_skips_falsy_results(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), + CredentialFormSchema( + variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False + ), + ] + # Result of false switch is False. if result: is false. Not added. + # Result of empty string is "", if result: is false. Not added. 
+ credentials = {"enabled": "false", "empty_str": ""} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "enabled" not in result + assert "empty_str" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py new file mode 100644 index 0000000000..bab2805276 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py @@ -0,0 +1,233 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormOption, + FormShowOnObject, + FormType, + ModelCredentialSchema, +) +from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator + + +def test_validate_and_filter_with_none_schema(): + validator = ModelCredentialSchemaValidator(ModelType.LLM, None) + with pytest.raises(ValueError, match="Model credential schema is None"): + validator.validate_and_filter({}) + + +def test_validate_and_filter_success(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ), + CredentialFormSchema( + variable="optional_field", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + default="default_val", + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + credentials = {"api_key": "sk-123456"} + result = validator.validate_and_filter(credentials) + + assert result["api_key"] == "sk-123456" + assert result["optional_field"] == "default_val" + assert credentials["__model_type"] == ModelType.LLM.value + + +def test_validate_and_filter_with_show_on(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=True, + show_on=[FormShowOnObject(variable="mode", value="advanced")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # mode is 'simple', conditional_field should be filtered out + credentials = {"mode": "simple", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert "conditional_field" not in result + assert result["mode"] == "simple" + + # mode is 'advanced', conditional_field should be kept + credentials = {"mode": "advanced", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert result["conditional_field"] == "secret" + assert result["mode"] == "advanced" + + # show_on variable missing in credentials + credentials = {"conditional_field": "secret"} # mode missing + with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema + 
validator.validate_and_filter(credentials) + + +def test_validate_and_filter_show_on_missing_trigger_var(): + # specifically test all_show_on_match = False when variable not in credentials + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional_trigger", + label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"), + type=FormType.TEXT_INPUT, + required=False, + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=False, + show_on=[FormShowOnObject(variable="optional_trigger", value="active")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # optional_trigger missing, conditional_field should be skipped + result = validator.validate_and_filter({"conditional_field": "val"}) + assert "conditional_field" not in result + + +def test_common_validator_logic_required(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({"api_key": ""}) + + +def test_common_validator_logic_max_length(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", + label=I18nObject(en_US="Key", zh_Hans="Key"), + type=FormType.TEXT_INPUT, + required=True, + max_length=5, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): + validator.validate_and_filter({"key": "123456"}) + + +def test_common_validator_logic_invalid_type(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key should be string"): + validator.validate_and_filter({"key": 123}) + + +def test_common_validator_logic_switch(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="enabled", + label=I18nObject(en_US="Enabled", zh_Hans="启用"), + type=FormType.SWITCH, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"enabled": "true"}) + assert result["enabled"] is True + + result = validator.validate_and_filter({"enabled": "false"}) + assert "enabled" not in result + + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator.validate_and_filter({"enabled": "not_a_bool"}) + + +def test_common_validator_logic_options(): + schema = ModelCredentialSchema( + 
model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="choice", + label=I18nObject(en_US="Choice", zh_Hans="选择"), + type=FormType.SELECT, + required=True, + options=[ + FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), + FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), + ], + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"choice": "a"}) + assert result["choice"] == "a" + + with pytest.raises(ValueError, match="Variable choice is not in options"): + validator.validate_and_filter({"choice": "c"}) + + +def test_validate_and_filter_optional_no_default(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({}) + assert "optional" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py new file mode 100644 index 0000000000..043306840f --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py @@ -0,0 +1,72 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema +from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) + + +class TestProviderCredentialSchemaValidator: + def test_validate_and_filter_success(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + required=False, + default="https://api.example.com", + ), + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test valid credentials + credentials = {"api_key": "my-secret-key"} + result = validator.validate_and_filter(credentials) + + assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} + + def test_validate_and_filter_missing_required(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test missing required credentials + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + def test_validate_and_filter_extra_fields_filtered(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test credentials 
with extra fields + credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} + result = validator.validate_and_filter(credentials) + + assert "api_key" in result + assert "extra_field" not in result + assert result == {"api_key": "my-secret-key"} + + def test_init(self): + schema = ProviderCredentialSchema(credential_form_schemas=[]) + validator = ProviderCredentialSchemaValidator(schema) + assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py new file mode 100644 index 0000000000..1ce8765a3b --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py @@ -0,0 +1,231 @@ +import dataclasses +import datetime +from collections import deque +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path, PurePath +from re import compile +from typing import Any +from unittest.mock import MagicMock +from uuid import UUID + +import pytest +from pydantic import BaseModel, ConfigDict +from pydantic.networks import AnyUrl, NameEmail +from pydantic.types import SecretBytes, SecretStr +from pydantic_core import Url +from pydantic_extra_types.color import Color + +from dify_graph.model_runtime.utils.encoders import ( + _model_dump, + decimal_encoder, + generate_encoders_by_class_tuples, + isoformat, + jsonable_encoder, +) + + +class MockEnum(Enum): + A = "a" + B = "b" + + +class MockPydanticModel(BaseModel): + model_config = ConfigDict(populate_by_name=True) + name: str + age: int + + +@dataclasses.dataclass +class MockDataclass: + name: str + value: Any + + +class MockWithDict: + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data.items()) + + +class MockWithVars: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class TestEncoders: + def test_model_dump(self): + model = MockPydanticModel(name="test", age=20) + result = _model_dump(model) + assert result == {"name": "test", "age": 20} + + def test_isoformat(self): + d = datetime.date(2023, 1, 1) + assert isoformat(d) == "2023-01-01" + t = datetime.time(12, 0, 0) + assert isoformat(t) == "12:00:00" + + def test_decimal_encoder(self): + assert decimal_encoder(Decimal("1.0")) == 1.0 + assert decimal_encoder(Decimal(1)) == 1 + assert decimal_encoder(Decimal("1.5")) == 1.5 + assert decimal_encoder(Decimal(0)) == 0 + assert decimal_encoder(Decimal(-1)) == -1 + + def test_generate_encoders_by_class_tuples(self): + type_map = {int: str, float: str, str: int} + result = generate_encoders_by_class_tuples(type_map) + assert result[str] == (int, float) + assert result[int] == (str,) + + def test_jsonable_encoder_basic_types(self): + assert jsonable_encoder("string") == "string" + assert jsonable_encoder(123) == 123 + assert jsonable_encoder(1.23) == 1.23 + assert jsonable_encoder(None) is None + + def test_jsonable_encoder_pydantic(self): + model = MockPydanticModel(name="test", age=20) + assert jsonable_encoder(model) == {"name": "test", "age": 20} + + def test_jsonable_encoder_pydantic_root(self): + # Manually create a mock that behaves like a model with __root__ + # because Pydantic v2 handles root differently, but the code checks for "__root__" + model = MagicMock(spec=BaseModel) + # _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...) 
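+ # Illustration (assumed semantics): a dump of {"__root__": [1, 2, 3]} should be
+ # unwrapped to the bare list, mirroring Pydantic v1 custom-root models.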
+ model.model_dump.return_value = {"__root__": [1, 2, 3]} + assert jsonable_encoder(model) == [1, 2, 3] + + def test_jsonable_encoder_dataclass(self): + obj = MockDataclass(name="test", value=1) + assert jsonable_encoder(obj) == {"name": "test", "value": 1} + # Passing the dataclass *type* rather than an instance is not encodable: + # it falls through the dict()/vars() fallbacks and raises ValueError + with pytest.raises(ValueError): + jsonable_encoder(MockDataclass) + + def test_jsonable_encoder_enum(self): + assert jsonable_encoder(MockEnum.A) == "a" + + def test_jsonable_encoder_path(self): + assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" + assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" + + def test_jsonable_encoder_decimal(self): + # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") + assert jsonable_encoder(Decimal("1.23")) == "1.23" + assert jsonable_encoder(Decimal("1.000")) == "1.000" + + def test_jsonable_encoder_dict(self): + d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]} + assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + + d_with_none = {"a": 1, "b": None} + assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} + assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} + + def test_jsonable_encoder_collections(self): + assert jsonable_encoder([1, 2]) == [1, 2] + assert jsonable_encoder((1, 2)) == [1, 2] + assert jsonable_encoder({1, 2}) == [1, 2] + assert jsonable_encoder(frozenset([1, 2])) == [1, 2] + assert jsonable_encoder(deque([1, 2])) == [1, 2] + + def gen(): + yield 1 + yield 2 + + assert jsonable_encoder(gen()) == [1, 2] + + def test_jsonable_encoder_custom_encoder(self): + custom = {int: lambda x: str(x + 1)} + assert jsonable_encoder(1, custom_encoder=custom) == "2" + + # Test subclass matching for custom encoder + class SubInt(int): + pass + + assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" + + def test_jsonable_encoder_special_types(self): + # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples + assert jsonable_encoder(b"bytes") == "bytes" + assert jsonable_encoder(Color("red")) == "red" + + dt = datetime.datetime(2023, 1, 1, 12, 0, 0) + assert jsonable_encoder(dt) == dt.isoformat() + + date = datetime.date(2023, 1, 1) + assert jsonable_encoder(date) == date.isoformat() + + time = datetime.time(12, 0, 0) + assert jsonable_encoder(time) == time.isoformat() + + td = datetime.timedelta(seconds=60) + assert jsonable_encoder(td) == 60.0 + + assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" + assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" + assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" + assert jsonable_encoder(IPv6Address("::1")) == "::1" + assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" + assert jsonable_encoder(IPv6Network("::/128")) == "::/128" + + assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test <test@example.com>" + + assert jsonable_encoder(compile("abc")) == "abc" + + # Secret types are masked; SecretBytes' string form may include a b'' wrapper, + # so check containment rather than equality + res_bytes = jsonable_encoder(SecretBytes(b"secret")) + assert "**********" in res_bytes + + res_str = jsonable_encoder(SecretStr("secret")) + assert res_str == "**********" + + u = UUID("12345678-1234-5678-1234-567812345678") + assert jsonable_encoder(u) == str(u) + + url = AnyUrl("https://example.com") + assert jsonable_encoder(url) == 
"https://example.com/" + + purl = Url("https://example.com") + assert jsonable_encoder(purl) == "https://example.com/" + + def test_jsonable_encoder_fallback(self): + # dict(obj) success + obj_dict = MockWithDict({"a": 1}) + assert jsonable_encoder(obj_dict) == {"a": 1} + + # vars(obj) success + obj_vars = MockWithVars(x=10, y=20) + assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} + + # error fallback + class ReallyUnserializable: + __slots__ = ["__weakref__"] # No __dict__ + + def __iter__(self): + raise TypeError("not iterable") + + with pytest.raises(ValueError) as exc: + jsonable_encoder(ReallyUnserializable()) + assert "not iterable" in str(exc.value) + + def test_jsonable_encoder_nested(self): + data = { + "model": MockPydanticModel(name="test", age=20), + "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], + "set": {1, 2}, + } + expected = { + "model": {"name": "test", "age": 20}, + "list": ["1.1", {"a": "/tmp"}], + "set": [1, 2], + } + assert jsonable_encoder(data) == expected diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index cc311d447f..1726fc2e8b 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -98,7 +98,7 @@ class TestAccountModelValidation: ) # Assert - assert account.status == "active" + assert account.status == AccountStatus.ACTIVE def test_account_get_status_method(self): """Test the get_status method returns AccountStatus enum.""" @@ -106,7 +106,7 @@ class TestAccountModelValidation: account = Account( name="Test User", email="test@example.com", - status="pending", + status=AccountStatus.PENDING, ) # Act diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py new file mode 100644 index 0000000000..c2fcd71875 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import MagicMock + +import httpx +import pytest + + +@pytest.fixture(scope="module") +def jina_module() -> ModuleType: + """ + Load `api/services/auth/jina.py` as a standalone module. + + This repo contains both `services/auth/jina.py` and a package at + `services/auth/jina/`, so importing `services.auth.jina` can be ambiguous. + """ + + module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py" + # Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`. 
+ spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict: + config: dict = {} if api_key is None else {"api_key": api_key} + return {"auth_type": auth_type, "config": config} + + +def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials()) + assert auth.api_key == "test_api_key_123" + assert auth.credentials["auth_type"] == "bearer" + + +def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None: + with pytest.raises(ValueError, match="Invalid auth type.*Bearer"): + jina_module.JinaAuth(_credentials(auth_type="basic")) + + +@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}]) +def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None: + with pytest.raises(ValueError, match="No API key provided"): + jina_module.JinaAuth(credentials) + + +def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + assert auth._prepare_headers() == {"Content-Type": "application/json", "Authorization": "Bearer k"} + + +def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + post_mock = MagicMock(name="httpx.post") + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) + post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) + + +def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 200 + post_mock = MagicMock(return_value=response) + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + assert auth.validate_credentials() is True + post_mock.assert_called_once_with( + "https://r.jina.ai", + headers={"Content-Type": "application/json", "Authorization": "Bearer k"}, + json={"url": "https://example.com"}, + ) + + +def test_validate_credentials_non_200_raises_via_handle_error( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 402 + response.json.return_value = {"error": "Payment required"} + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + + with pytest.raises(Exception, match="Status code: 402.*Payment required"): + auth.validate_credentials() + + +@pytest.mark.parametrize("status_code", [402, 409, 500]) +def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = status_code + response.json.return_value = {"error": "boom"} + + with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"): + auth._handle_error(response) + + +def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None: + auth = 
jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 402 + response.json.return_value = {} + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = '{"error": "Forbidden"}' + + with pytest.raises(Exception, match="Status code: 403.*Forbidden"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = "{}" + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 404 + response.text = "" + + with pytest.raises(Exception, match="Unexpected error occurred.*404"): + auth._handle_error(response) + + +def test_validate_credentials_propagates_network_errors( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + + with pytest.raises(httpx.ConnectError, match="boom"): + auth.validate_credentials() diff --git a/api/tests/unit_tests/services/plugin/__init__.py b/api/tests/unit_tests/services/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/plugin/conftest.py b/api/tests/unit_tests/services/plugin/conftest.py new file mode 100644 index 0000000000..80c6077b0c --- /dev/null +++ b/api/tests/unit_tests/services/plugin/conftest.py @@ -0,0 +1,39 @@ +"""Shared fixtures for services.plugin test suite.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from services.feature_service import PluginInstallationScope + + +def make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + """Create a mock FeatureService.get_system_features() result.""" + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features + + +@pytest.fixture +def mock_installer(monkeypatch): + """Patch PluginInstaller at the service import site.""" + mock = MagicMock() + monkeypatch.setattr("services.plugin.plugin_service.PluginInstaller", lambda: mock) + return mock + + +@pytest.fixture +def mock_features(): + """Patch FeatureService to return permissive defaults.""" + from unittest.mock import patch + + features = make_features() + with patch("services.plugin.plugin_service.FeatureService") as mock_fs: + mock_fs.get_system_features.return_value = features + yield features diff --git a/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py b/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py new file mode 100644 index 0000000000..8f0886769c --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py @@ -0,0 +1,172 @@ +"""Tests for 
services.plugin.dependencies_analysis.DependenciesAnalysisService. + +Covers: provider ID resolution, leaked dependency detection with version +extraction, dependency generation from multiple sources, and latest +dependencies via marketplace. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource +from services.plugin.dependencies_analysis import DependenciesAnalysisService + + +class TestAnalyzeToolDependency: + def test_valid_three_part_id(self): + result = DependenciesAnalysisService.analyze_tool_dependency("langgenius/google/google") + assert result == "langgenius/google" + + def test_single_part_expands_to_langgenius(self): + result = DependenciesAnalysisService.analyze_tool_dependency("websearch") + assert result == "langgenius/websearch" + + def test_invalid_format_raises(self): + with pytest.raises(ValueError): + DependenciesAnalysisService.analyze_tool_dependency("bad/format") + + +class TestAnalyzeModelProviderDependency: + def test_valid_three_part_id(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("langgenius/openai/openai") + assert result == "langgenius/openai" + + def test_google_maps_to_gemini(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("langgenius/google/google") + assert result == "langgenius/gemini" + + def test_single_part_expands(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("anthropic") + assert result == "langgenius/anthropic" + + +class TestGetLeakedDependencies: + def _make_dependency(self, identifier: str, dep_type=PluginDependency.Type.Marketplace): + return PluginDependency( + type=dep_type, + value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=identifier), + ) + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_returns_empty_when_all_present(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [] + deps = [self._make_dependency("org/plugin:1.0.0@hash")] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert result == [] + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_returns_missing_with_version_extracted(self, mock_installer_cls): + missing = MagicMock() + missing.plugin_unique_identifier = "org/plugin:1.2.3@hash" + missing.current_identifier = "org/plugin:1.0.0@oldhash" + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [missing] + + deps = [self._make_dependency("org/plugin:1.2.3@hash")] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert len(result) == 1 + assert result[0].value.version == "1.2.3" + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_skips_present_dependencies(self, mock_installer_cls): + missing = MagicMock() + missing.plugin_unique_identifier = "org/missing:1.0.0@hash" + missing.current_identifier = None + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [missing] + + deps = [ + self._make_dependency("org/present:1.0.0@hash"), + self._make_dependency("org/missing:1.0.0@hash"), + ] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert len(result) == 1 + + +class TestGenerateDependencies: + def _make_installation(self, source, identifier, meta=None): + install = MagicMock() + install.source = source + 
install.plugin_unique_identifier = identifier + install.meta = meta or {} + return install + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_github_source(self, mock_installer_cls): + install = self._make_installation( + PluginInstallationSource.Github, + "org/plugin:1.0.0@hash", + {"repo": "org/repo", "version": "v1.0", "package": "plugin.difypkg"}, + ) + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert len(result) == 1 + assert result[0].type == PluginDependency.Type.Github + assert result[0].value.repo == "org/repo" + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_marketplace_source(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Marketplace, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert result[0].type == PluginDependency.Type.Marketplace + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_package_source(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Package, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert result[0].type == PluginDependency.Type.Package + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_remote_source_raises(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Remote, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + with pytest.raises(ValueError, match="remote plugin"): + DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_deduplicates_input_ids(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [] + + DependenciesAnalysisService.generate_dependencies("t1", ["p1", "p1", "p2"]) + + call_args = mock_installer_cls.return_value.fetch_plugin_installation_by_ids.call_args[0] + assert len(call_args[1]) == 2 # deduplicated + + +class TestGenerateLatestDependencies: + @patch("services.plugin.dependencies_analysis.dify_config") + def test_returns_empty_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + result = DependenciesAnalysisService.generate_latest_dependencies(["p1"]) + + assert result == [] + + @patch("services.plugin.dependencies_analysis.marketplace") + @patch("services.plugin.dependencies_analysis.dify_config") + def test_returns_marketplace_deps_when_enabled(self, mock_config, mock_marketplace): + mock_config.MARKETPLACE_ENABLED = True + manifest = MagicMock() + manifest.latest_package_identifier = "org/plugin:2.0.0@newhash" + mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] + + result = DependenciesAnalysisService.generate_latest_dependencies(["p1"]) + + assert len(result) == 1 + assert result[0].type == PluginDependency.Type.Marketplace diff --git a/api/tests/unit_tests/services/plugin/test_endpoint_service.py b/api/tests/unit_tests/services/plugin/test_endpoint_service.py new file mode 100644 index 0000000000..ddf80c8017 --- 
/dev/null +++ b/api/tests/unit_tests/services/plugin/test_endpoint_service.py @@ -0,0 +1,41 @@ +"""Tests for services.plugin.endpoint_service.EndpointService. + +Smoke tests to confirm delegation to PluginEndpointClient. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from services.plugin.endpoint_service import EndpointService + + +class TestEndpointServiceDelegation: + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_create_delegates_correctly(self, mock_client_cls): + expected = MagicMock() + mock_client_cls.return_value.create_endpoint.return_value = expected + + result = EndpointService.create_endpoint("t1", "u1", "uid-1", "my-endpoint", {"key": "val"}) + + assert result is expected + mock_client_cls.return_value.create_endpoint.assert_called_once_with( + tenant_id="t1", user_id="u1", plugin_unique_identifier="uid-1", name="my-endpoint", settings={"key": "val"} + ) + + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_list_delegates_correctly(self, mock_client_cls): + expected = MagicMock() + mock_client_cls.return_value.list_endpoints.return_value = expected + + result = EndpointService.list_endpoints("t1", "u1", 1, 10) + + assert result is expected + + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_enable_disable_delegates(self, mock_client_cls): + EndpointService.enable_endpoint("t1", "u1", "ep-1") + mock_client_cls.return_value.enable_endpoint.assert_called_once() + + EndpointService.disable_endpoint("t1", "u1", "ep-2") + mock_client_cls.return_value.disable_endpoint.assert_called_once() diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py new file mode 100644 index 0000000000..27df4556bc --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -0,0 +1,90 @@ +"""Tests for services.plugin.oauth_service.OAuthProxyService. + +Covers: CSRF proxy context creation with Redis TTL, context consumption +with one-time use semantics, and validation error paths. 
+""" + +from __future__ import annotations + +import json + +import pytest + +from services.plugin.oauth_service import OAuthProxyService + + +class TestCreateProxyContext: + def test_stores_context_in_redis_with_ttl(self): + context_id = OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github" + ) + + assert context_id # non-empty UUID string + from extensions.ext_redis import redis_client + + redis_client.setex.assert_called_once() + call_args = redis_client.setex.call_args + key = call_args[0][0] + ttl = call_args[0][1] + stored_data = json.loads(call_args[0][2]) + + assert key.startswith("oauth_proxy_context:") + assert ttl == 5 * 60 + assert stored_data["user_id"] == "u1" + assert stored_data["tenant_id"] == "t1" + assert stored_data["plugin_id"] == "p1" + assert stored_data["provider"] == "github" + + def test_includes_credential_id_when_provided(self): + OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github", credential_id="cred-1" + ) + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert stored_data["credential_id"] == "cred-1" + + def test_excludes_credential_id_when_none(self): + OAuthProxyService.create_proxy_context(user_id="u1", tenant_id="t1", plugin_id="p1", provider="github") + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert "credential_id" not in stored_data + + def test_includes_extra_data(self): + OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github", extra_data={"scope": "repo"} + ) + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert stored_data["scope"] == "repo" + + +class TestUseProxyContext: + def test_raises_when_context_id_empty(self): + with pytest.raises(ValueError, match="context_id is required"): + OAuthProxyService.use_proxy_context("") + + def test_raises_when_context_not_found(self): + from extensions.ext_redis import redis_client + + redis_client.get.return_value = None + + with pytest.raises(ValueError, match="context_id is invalid"): + OAuthProxyService.use_proxy_context("nonexistent-id") + + def test_returns_data_and_deletes_key(self): + from extensions.ext_redis import redis_client + + stored = {"user_id": "u1", "tenant_id": "t1", "plugin_id": "p1", "provider": "github"} + redis_client.get.return_value = json.dumps(stored).encode() + + result = OAuthProxyService.use_proxy_context("valid-id") + + assert result == stored + expected_key = "oauth_proxy_context:valid-id" + redis_client.delete.assert_called_once_with(expected_key) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py new file mode 100644 index 0000000000..bfa9fe976b --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py @@ -0,0 +1,216 @@ +"""Tests for services.plugin.plugin_parameter_service.PluginParameterService. + +Covers: dynamic select options via tool and trigger credential paths, +HIDDEN_VALUE replacement, and error handling for missing records. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from services.plugin.plugin_parameter_service import PluginParameterService + + +class TestGetDynamicSelectOptionsTool: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_no_credentials_needed(self, mock_tool_mgr, mock_client_cls): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = False + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) + + assert result == ["opt1"] + call_kwargs = mock_client_cls.return_value.fetch_dynamic_select_options.call_args + assert call_kwargs[0][5] == {} # empty credentials + + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") + @patch("services.plugin.plugin_parameter_service.db") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = True + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + encrypter = MagicMock() + encrypter.decrypt.return_value = {"api_key": "decrypted"} + mock_encrypter_fn.return_value = (encrypter, None) + + # Mock the Session/query chain + db_record = MagicMock() + db_record.credentials = {"api_key": "encrypted"} + db_record.credential_type = "api_key" + + with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_session.query.return_value.where.return_value.first.return_value = db_record + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id="cred-1", + provider_type="tool", + ) + + assert result == ["opt1"] + + @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") + @patch("services.plugin.plugin_parameter_service.db") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = True + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + mock_encrypter_fn.return_value = (MagicMock(), None) + + with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="not found"): + 
PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) + + +class TestGetDynamicSelectOptionsTrigger: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_uses_subscription_builder_when_credential_id(self, mock_builder_svc, mock_client_cls): + sub = MagicMock() + sub.credentials = {"token": "abc"} + sub.credential_type = "api_key" + mock_builder_svc.get_subscription_builder.return_value = sub + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="builder-1", + provider_type="trigger", + ) + + assert result == ["opt"] + + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_falls_back_to_trigger_service(self, mock_builder_svc, mock_provider_svc, mock_client_cls): + mock_builder_svc.get_subscription_builder.return_value = None + trigger_sub = MagicMock() + api_entity = MagicMock() + api_entity.credentials = {"token": "abc"} + api_entity.credential_type = "api_key" + trigger_sub.to_api_entity.return_value = api_entity + mock_provider_svc.get_subscription_by_id.return_value = trigger_sub + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="sub-1", + provider_type="trigger", + ) + + assert result == ["opt"] + + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_raises_when_no_subscription_found(self, mock_builder_svc, mock_provider_svc): + mock_builder_svc.get_subscription_builder.return_value = None + mock_provider_svc.get_subscription_by_id.return_value = None + + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="nonexistent", + provider_type="trigger", + ) + + +class TestGetDynamicSelectOptionsWithCredentials: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + def test_replaces_hidden_values(self, mock_provider_svc, mock_client_cls): + from constants import HIDDEN_VALUE + + original = MagicMock() + original.credentials = {"token": "real-secret", "name": "real-name"} + original.credential_type = "api_key" + mock_provider_svc.get_subscription_by_id.return_value = original + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options_with_credentials( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + 
credential_id="cred-1", + credentials={"token": HIDDEN_VALUE, "name": "new-name"}, + ) + + assert result == ["opt"] + call_args = mock_client_cls.return_value.fetch_dynamic_select_options.call_args[0] + resolved = call_args[5] + assert resolved["token"] == "real-secret" # replaced + assert resolved["name"] == "new-name" # kept as-is + + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + def test_raises_when_subscription_not_found(self, mock_provider_svc): + mock_provider_svc.get_subscription_by_id.return_value = None + + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options_with_credentials( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="nonexistent", + credentials={"token": "val"}, + ) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/unit_tests/services/plugin/test_plugin_service.py new file mode 100644 index 0000000000..09b9ab498b --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_service.py @@ -0,0 +1,357 @@ +"""Tests for services.plugin.plugin_service.PluginService. + +Covers: version caching with Redis, install permission/scope gates, +icon URL construction, asset retrieval with MIME guessing, plugin +verification, marketplace upgrade flows, and uninstall with credential cleanup. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin import PluginInstallationSource +from core.plugin.entities.plugin_daemon import PluginVerification +from services.errors.plugin import PluginInstallationForbiddenError +from services.feature_service import PluginInstallationScope +from services.plugin.plugin_service import PluginService +from tests.unit_tests.services.plugin.conftest import make_features + + +class TestFetchLatestPluginVersion: + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_returns_cached_version(self, mock_redis, mock_marketplace): + cached_json = PluginService.LatestPluginCache( + plugin_id="p1", + version="1.0.0", + unique_identifier="uid-1", + status="active", + deprecated_reason="", + alternative_plugin_id="", + ).model_dump_json() + mock_redis.get.return_value = cached_json + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result["p1"].version == "1.0.0" + mock_marketplace.batch_fetch_plugin_manifests.assert_not_called() + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + manifest = MagicMock() + manifest.plugin_id = "p1" + manifest.latest_version = "2.0.0" + manifest.latest_package_identifier = "uid-2" + manifest.status = "active" + manifest.deprecated_reason = "" + manifest.alternative_plugin_id = "" + mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result["p1"].version == "2.0.0" + mock_redis.setex.assert_called_once() + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + mock_marketplace.batch_fetch_plugin_manifests.return_value = [] 
+ + result = PluginService.fetch_latest_plugin_version(["unknown"]) + + assert result["unknown"] is None + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error") + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result == {} + + +class TestCheckMarketplaceOnlyPermission: + @patch("services.plugin.plugin_service.FeatureService") + def test_raises_when_restricted(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_marketplace_only_permission() + + @patch("services.plugin.plugin_service.FeatureService") + def test_passes_when_not_restricted(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + + PluginService._check_marketplace_only_permission() # should not raise + + +class TestCheckPluginInstallationScope: + @patch("services.plugin.plugin_service.FeatureService") + def test_official_only_allows_langgenius(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius + + PluginService._check_plugin_installation_scope(verification) # should not raise + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_only_rejects_third_party(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(None) + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_and_partners_allows_partner(self, mock_fs): + mock_fs.get_system_features.return_value = make_features( + scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS + ) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Partner + + PluginService._check_plugin_installation_scope(verification) # should not raise + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_and_partners_rejects_none(self, mock_fs): + mock_fs.get_system_features.return_value = make_features( + scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS + ) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(None) + + @patch("services.plugin.plugin_service.FeatureService") + def test_none_scope_always_raises(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(verification) + + @patch("services.plugin.plugin_service.FeatureService") + def test_all_scope_passes_any(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + + PluginService._check_plugin_installation_scope(None) # should not raise + + +class TestGetPluginIconUrl: + 
@patch("services.plugin.plugin_service.dify_config") + def test_constructs_url_with_params(self, mock_config): + mock_config.CONSOLE_API_URL = "https://console.example.com" + + url = PluginService.get_plugin_icon_url("tenant-1", "icon.svg") + + assert "tenant_id=tenant-1" in url + assert "filename=icon.svg" in url + assert "/plugin/icon" in url + + +class TestGetAsset: + @patch("services.plugin.plugin_service.PluginAssetManager") + def test_returns_bytes_and_guessed_mime(self, mock_asset_cls): + mock_asset_cls.return_value.fetch_asset.return_value = b"" + + data, mime = PluginService.get_asset("t1", "icon.svg") + + assert data == b"" + assert "svg" in mime + + @patch("services.plugin.plugin_service.PluginAssetManager") + def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls): + mock_asset_cls.return_value.fetch_asset.return_value = b"\x00" + + _, mime = PluginService.get_asset("t1", "unknown_file") + + assert mime == "application/octet-stream" + + +class TestIsPluginVerified: + @patch("services.plugin.plugin_service.PluginInstaller") + def test_returns_true_when_verified(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True + + assert PluginService.is_plugin_verified("t1", "uid-1") is True + + @patch("services.plugin.plugin_service.PluginInstaller") + def test_returns_false_on_exception(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found") + + assert PluginService.is_plugin_verified("t1", "uid-1") is False + + +class TestUpgradePluginWithMarketplace: + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + with pytest.raises(ValueError, match="marketplace is not enabled"): + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_same_identifier(self, mock_config): + mock_config.MARKETPLACE_ENABLED = True + + with pytest.raises(ValueError, match="same plugin"): + PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid") + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid") + installer.upgrade_plugin.assert_called_once() + + @patch("services.plugin.plugin_service.download_plugin_pkg") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = 
mock_installer_cls.return_value + installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") + mock_download.return_value = b"pkg-bytes" + upload_resp = MagicMock() + upload_resp.verification = None + installer.upload_pkg.return_value = upload_resp + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + mock_download.assert_called_once_with("new-uid") + installer.upload_pkg.assert_called_once() + + +class TestUpgradePluginWithGithub: + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_github("t1", "old-uid", "new-uid", "org/repo", "v1", "pkg.difypkg") + + installer.upgrade_plugin.assert_called_once() + call_args = installer.upgrade_plugin.call_args + assert call_args[0][3] == PluginInstallationSource.Github + + +class TestUploadPkg: + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): + mock_fs.get_system_features.return_value = make_features() + upload_resp = MagicMock() + upload_resp.verification = None + mock_installer_cls.return_value.upload_pkg.return_value = upload_resp + + result = PluginService.upload_pkg("t1", b"pkg-bytes") + + assert result is upload_resp + + +class TestInstallFromMarketplacePkg: + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + with pytest.raises(ValueError, match="marketplace is not enabled"): + PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + @patch("services.plugin.plugin_service.download_plugin_pkg") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") + mock_download.return_value = b"pkg" + upload_resp = MagicMock() + upload_resp.verification = None + upload_resp.unique_identifier = "resolved-uid" + installer.upload_pkg.return_value = upload_resp + installer.install_from_identifiers.return_value = "task-id" + + result = PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + assert result == "task-id" + installer.install_from_identifiers.assert_called_once() + call_args = installer.install_from_identifiers.call_args[0] + assert call_args[1] == ["resolved-uid"] # uses response uid, not input + + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + 
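# --- Editor's note (not part of this change): the identifier-resolution flow
# the two tests around here assert, sketched as plain Python. Treating a failed
# fetch_plugin_manifest as "not installed yet" and preferring the
# daemon-resolved identifier over the input are behaviors taken from the
# assertions; function and parameter names are placeholders.
def resolve_unique_identifier(installer, tenant_id, identifier, download_pkg):
    try:
        installer.fetch_plugin_manifest(tenant_id, identifier)
        return identifier                  # already on the daemon: reuse as-is
    except Exception:
        pkg = download_pkg(identifier)     # pull the package from the marketplace
        response = installer.upload_pkg(tenant_id, pkg)
        return response.unique_identifier  # the daemon-resolved identifier wins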
installer.fetch_plugin_manifest.return_value = MagicMock() + decode_resp = MagicMock() + decode_resp.verification = None + installer.decode_plugin_from_identifier.return_value = decode_resp + installer.install_from_identifiers.return_value = "task-id" + + PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + installer.install_from_identifiers.assert_called_once() + call_args = installer.install_from_identifiers.call_args[0] + assert call_args[1] == ["uid-1"] # uses original uid + + +class TestUninstall: + @patch("services.plugin.plugin_service.PluginInstaller") + def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls): + installer = mock_installer_cls.return_value + installer.list_plugins.return_value = [] + installer.uninstall.return_value = True + + result = PluginService.uninstall("t1", "install-1") + + assert result is True + installer.uninstall.assert_called_once_with("t1", "install-1") + + @patch("services.plugin.plugin_service.db") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + plugin = MagicMock() + plugin.installation_id = "install-1" + plugin.plugin_id = "org/myplugin" + installer = mock_installer_cls.return_value + installer.list_plugins.return_value = [plugin] + installer.uninstall.return_value = True + + # Mock Session context manager + mock_session = MagicMock() + mock_db.engine = MagicMock() + mock_session.scalars.return_value.all.return_value = [] # no credentials found + + with patch("services.plugin.plugin_service.Session") as mock_session_cls: + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + + result = PluginService.uninstall("t1", "install-1") + + assert result is True + installer.uninstall.assert_called_once() diff --git a/api/tests/unit_tests/services/recommend_app/__init__.py b/api/tests/unit_tests/services/recommend_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py new file mode 100644 index 0000000000..770344aa39 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py @@ -0,0 +1,91 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + +SAMPLE_BUILTIN_DATA = { + "recommended_apps": { + "en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]}, + "zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]}, + }, + "app_details": { + "app-1": {"id": "app-1", "name": "Writer", "mode": "chat"}, + "app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"}, + }, +} + + +@pytest.fixture(autouse=True) +def _reset_cache(): + BuildInRecommendAppRetrieval.builtin_data = None + yield + BuildInRecommendAppRetrieval.builtin_data = None + + +class TestBuildInRecommendAppRetrieval: + def test_get_type(self): + retrieval = BuildInRecommendAppRetrieval() + assert retrieval.get_type() == RecommendAppType.BUILDIN + + def test_get_recommended_apps_and_categories_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_apps_from_builtin", + return_value={"apps": []}, + ) as mock_fetch: + retrieval = 
BuildInRecommendAppRetrieval() + result = retrieval.get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"apps": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_app_detail_from_builtin", + return_value={"id": "app-1"}, + ) as mock_fetch: + retrieval = BuildInRecommendAppRetrieval() + result = retrieval.get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + def test_get_builtin_data_reads_json_and_caches(self, tmp_path): + json_file = tmp_path / "constants" / "recommended_apps.json" + json_file.parent.mkdir(parents=True) + json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA)) + + mock_app = MagicMock() + mock_app.root_path = str(tmp_path) + + with patch( + "services.recommend_app.buildin.buildin_retrieval.current_app", + mock_app, + ): + first = BuildInRecommendAppRetrieval._get_builtin_data() + second = BuildInRecommendAppRetrieval._get_builtin_data() + + assert first == SAMPLE_BUILTIN_DATA + assert first is second + + def test_fetch_recommended_apps_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US") + assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"] + + def test_fetch_recommended_apps_from_builtin_missing_language(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR") + assert result == {} + + def test_fetch_recommended_app_detail_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1") + assert result == {"id": "app-1", "name": "Writer", "mode": "chat"} + + def test_fetch_recommended_app_detail_from_builtin_missing(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent") + assert result is None diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..5d21665f75 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,145 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = 
DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): + site = ( + SimpleNamespace( + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + if has_site + else None + ) + app = ( + SimpleNamespace(is_public=is_public, site=site) + if is_public + else SimpleNamespace(is_public=False, site=site) + ) + return SimpleNamespace( + id=f"rec-{app_id}", + app=app, + app_id=app_id, + category=category, + position=1, + is_listed=True, + ) + + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_apps_and_sorted_categories(self, mock_db): + rec1 = self._make_recommended_app("a1", "writing") + rec2 = self._make_recommended_app("a2", "assistant") + mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert len(result["recommended_apps"]) == 2 + assert result["categories"] == ["assistant", "writing"] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_falls_back_to_default_language_when_empty(self, mock_db): + mock_db.session.scalars.return_value.all.side_effect = [ + [], + [self._make_recommended_app("a1", "chat")], + ] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + assert len(result["recommended_apps"]) == 1 + assert mock_db.session.scalars.call_count == 2 + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_non_public_apps(self, mock_db): + rec = self._make_recommended_app("a1", "chat", is_public=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_apps_without_site(self, mock_db): + rec = self._make_recommended_app("a1", "chat", has_site=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + +class TestFetchRecommendedAppDetailFromDb: + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_not_listed(self, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) + mock_db.session.query.side_effect = [rec_chain, app_chain] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_detail_on_success(self, 
mock_db, mock_dsl): + app_model = SimpleNamespace( + id="app-1", + name="My App", + icon="icon.png", + icon_background="#fff", + mode="chat", + is_public=True, + ) + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = app_model + mock_db.session.query.side_effect = [rec_chain, app_chain] + mock_dsl.export_dsl.return_value = "exported_yaml" + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result["id"] == "app-1" + assert result["name"] == "My App" + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py new file mode 100644 index 0000000000..036cba0cc0 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py @@ -0,0 +1,28 @@ +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class TestRecommendAppRetrievalFactory: + @pytest.mark.parametrize( + ("mode", "expected_class"), + [ + ("remote", RemoteRecommendAppRetrieval), + ("builtin", BuildInRecommendAppRetrieval), + ("db", DatabaseRecommendAppRetrieval), + ], + ) + def test_factory_returns_correct_class(self, mode, expected_class): + result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode) + assert result is expected_class + + def test_factory_raises_for_unknown_mode(self): + with pytest.raises(ValueError, match="invalid fetch recommended apps mode"): + RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode") + + def test_get_buildin_recommend_app_retrieval(self): + result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval() + assert result is BuildInRecommendAppRetrieval diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py new file mode 100644 index 0000000000..08f72a6f77 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py @@ -0,0 +1,18 @@ +from services.recommend_app.recommend_app_type import RecommendAppType + + +def test_enum_values(): + assert RecommendAppType.REMOTE == "remote" + assert RecommendAppType.BUILDIN == "builtin" + assert RecommendAppType.DATABASE == "db" + + +def test_enum_membership(): + assert "remote" in RecommendAppType.__members__.values() + assert "builtin" in RecommendAppType.__members__.values() + assert "db" in RecommendAppType.__members__.values() + + +def test_enum_is_str(): + for member in RecommendAppType: + assert isinstance(member, str) diff --git a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py new file mode 100644 index 0000000000..e322fbed4c --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.recommend_app_type import RecommendAppType +from services.recommend_app.remote.remote_retrieval import 
RemoteRecommendAppRetrieval + + +class TestRemoteRecommendAppRetrieval: + def test_get_type(self): + assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + return_value={"id": "app-1"}, + ) + def test_get_recommend_app_detail_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "app-1"} + mock_fetch.assert_called_once_with("app-1") + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin", + return_value={"id": "fallback"}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + side_effect=ConnectionError("timeout"), + ) + def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "fallback"} + mock_builtin.assert_called_once_with("app-1") + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + return_value={"recommended_apps": [], "categories": []}, + ) + def test_get_recommended_apps_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [], "categories": []} + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin", + return_value={"recommended_apps": [{"id": "builtin"}]}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + side_effect=ValueError("server error"), + ) + def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [{"id": "builtin"}]} + + +class TestFetchFromDifyOfficial: + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_json_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"id": "app-1", "name": "Test"} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result == {"id": "app-1", "name": "Test"} + mock_get.assert_called_once() + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_none_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=404) + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result is None + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + 
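# --- Editor's note (not part of this change): the fallback contract the
# TestRemoteRecommendAppRetrieval cases above establish: any exception from the
# remote fetch falls back to the bundled builtin data. A generic sketch; the
# real logic lives inside RemoteRecommendAppRetrieval's methods.
def fetch_with_builtin_fallback(fetch_remote, fetch_builtin, *args):
    try:
        return fetch_remote(*args)
    except Exception:   # ConnectionError, ValueError, ...: all fall back
        return fetch_builtin(*args)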
mock_response.json.return_value = { + "recommended_apps": [], + "categories": ["writing", "agent", "chat"], + } + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert result["categories"] == ["agent", "chat", "writing"] + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_raises_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=500) + + with pytest.raises(ValueError, match="fetch recommended apps failed"): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_without_categories_key(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert "categories" not in result diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py index 6097bcbd61..4bfdba87a0 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -13,6 +13,7 @@ from datetime import datetime from unittest.mock import Mock, create_autospec, patch import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table from libs.archive_storage import ArchiveStorageNotConfiguredError from models.trigger import WorkflowTriggerLog @@ -127,10 +128,41 @@ class WorkflowRunRestoreTestDataFactory: if tables_data is None: tables_data = { - "workflow_runs": [{"id": "run-123", "tenant_id": "tenant-123"}], + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + } + ], "workflow_app_logs": [ - {"id": "log-1", "workflow_run_id": "run-123"}, - {"id": "log-2", "workflow_run_id": "run-123"}, + { + "id": "log-1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "log-2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, ], } @@ -406,14 +438,48 @@ class TestGetModelColumnInfo: assert "created_by" in column_names assert "status" in column_names - # WorkflowRun model has no required columns (all have defaults or are auto-generated) - assert required_columns == set() + # Columns without defaults should be required for restore inserts. 
+ assert { + "tenant_id", + "app_id", + "workflow_id", + "type", + "triggered_from", + "version", + "status", + "created_by_role", + "created_by", + }.issubset(required_columns) + assert "id" not in required_columns + assert "created_at" not in required_columns # Check columns with defaults or server defaults assert "id" in non_nullable_with_default assert "created_at" in non_nullable_with_default assert "elapsed_time" in non_nullable_with_default assert "total_tokens" in non_nullable_with_default + assert "tenant_id" not in non_nullable_with_default + + def test_non_pk_auto_autoincrement_column_is_still_required(self): + """`autoincrement='auto'` should not mark non-PK columns as defaulted.""" + restore = WorkflowRunRestore() + + test_table = Table( + "test_autoincrement", + MetaData(), + Column("id", Integer, primary_key=True, autoincrement=True), + Column("required_field", String(255), nullable=False), + Column("defaulted_field", String(255), nullable=False, default="x"), + ) + + class MockModel: + __table__ = test_table + + _, required_columns, non_nullable_with_default = restore._get_model_column_info(MockModel) + + assert required_columns == {"required_field"} + assert "id" in non_nullable_with_default + assert "defaulted_field" in non_nullable_with_default # --------------------------------------------------------------------------- @@ -465,7 +531,32 @@ class TestRestoreTableRecords: mock_stmt.on_conflict_do_nothing.return_value = mock_stmt mock_pg_insert.return_value = mock_stmt - records = [{"id": "test1", "tenant_id": "tenant-123"}, {"id": "test2", "tenant_id": "tenant-123"}] + records = [ + { + "id": "test1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "test2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + ] result = restore._restore_table_records(mock_session, "workflow_runs", records, schema_version="1.0") @@ -477,8 +568,7 @@ class TestRestoreTableRecords: restore = WorkflowRunRestore() mock_session = Mock() - # Since WorkflowRun has no required columns, we need to test with a different model - # Let's test with a mock model that has required columns + # Use a dedicated mock model to isolate required-column validation behavior. 
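# --- Editor's note (not part of this change): the classification rule the
# _get_model_column_info tests above pin down, sketched against SQLAlchemy's
# public Column attributes. Whether the real _get_model_column_info matches
# this shape exactly is an assumption; the three outcomes mirror the
# assertions on required_columns and non_nullable_with_default.
from sqlalchemy import Column

def classify_column(column: Column) -> str:
    is_autoincrement_pk = bool(column.primary_key) and column.autoincrement is True
    has_default = column.default is not None or column.server_default is not None
    if column.nullable:
        return "optional"                   # nullable columns are never required
    if has_default or is_autoincrement_pk:
        return "non_nullable_with_default"  # an insert may safely omit it
    return "required"                       # archived records must supply it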
mock_model = Mock() # Mock a required column @@ -965,6 +1055,13 @@ class TestIntegration: "id": "run-123", "tenant_id": "tenant-123", "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", "created_at": "2024-01-01T12:00:00", } ], diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 3c0db51cd2..1926cb133a 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -258,38 +258,38 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - return q msg_session_1 = MagicMock() - msg_session_1.query.side_effect = ( - lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() + msg_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() ) msg_session_1.commit.return_value = None msg_session_2 = MagicMock() - msg_session_2.query.side_effect = ( - lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock() + msg_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.Message else MagicMock() ) msg_session_2.commit.return_value = None conv_session_1 = MagicMock() - conv_session_1.query.side_effect = ( - lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() + conv_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() ) conv_session_1.commit.return_value = None conv_session_2 = MagicMock() - conv_session_2.query.side_effect = ( - lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() + conv_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() ) conv_session_2.commit.return_value = None wal_session_1 = MagicMock() - wal_session_1.query.side_effect = ( - lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() + wal_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() ) wal_session_1.commit.return_value = None wal_session_2 = MagicMock() - wal_session_2.query.side_effect = ( - lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() + wal_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() ) wal_session_2.commit.return_value = None diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py new file mode 100644 index 0000000000..ab7b473790 --- /dev/null +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -0,0 +1,381 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import App, TraceAppConfig +from services.ops_service import OpsService + + +class TestOpsService: + @patch("services.ops_service.db") + 
@patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + mock_db.session.query.assert_called_with(TraceAppConfig) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + assert mock_db.session.query.call_count == 2 + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = None + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + # Act & Assert + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config("app_id", "arize") + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == default_url + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] + ) + def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + 
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == "success_url" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act + result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) + + # Assert + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) + + # Assert + assert result == {"error": "Invalid Credentials"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + 
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): + # Arrange + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, config) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + result = OpsService.create_tracing_app_config( + "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} + ) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + # 'project' is in other_keys for Arize + # provide an empty string for the project in the tracing_config + # create_tracing_app_config will replace it with the default from the model + result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) + + # Assert + assert result == {"result": 
"success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"result": "success"} + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act & Assert + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act & Assert + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config("app_id", provider, {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + current_config.to_dict.return_value = {"some": "data"} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + + # Act + result = OpsService.update_tracing_app_config("app_id", 
provider, {}) + + # Assert + assert result == {"some": "data"} + mock_db.session.commit.assert_called_once() + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_no_config(self, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_success(self, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is True + mock_db.session.delete.assert_called_with(trace_config) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py new file mode 100644 index 0000000000..c7e1fed21f --- /dev/null +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -0,0 +1,1329 @@ +"""Unit tests for services.summary_index_service.""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import services.summary_index_service as summary_module +from services.summary_index_service import SummaryIndexService + + +@dataclass(frozen=True) +class _SessionContext: + session: MagicMock + + def __enter__(self) -> MagicMock: + return self.session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding" + return dataset + + +def _segment(*, has_document: bool = True) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = "seg-1" + segment.document_id = "doc-1" + segment.dataset_id = "dataset-1" + segment.content = "hello world" + segment.enabled = True + segment.status = "completed" + segment.position = 1 + if has_document: + doc = MagicMock(name="document") + doc.doc_language = "en" + doc.doc_form = "text_model" + segment.document = doc + else: + segment.document = None + return segment + + +def _summary_record(*, summary_content: str = "summary", node_id: str | None = None) -> MagicMock: + record = MagicMock(spec=summary_module.DocumentSegmentSummary, name="summary_record") + record.id = "sum-1" + record.dataset_id = "dataset-1" + record.document_id = "doc-1" + record.chunk_id = "seg-1" + record.summary_content = summary_content + record.summary_index_node_id = node_id + record.summary_index_node_hash = None + record.tokens = None + record.status = "generating" + record.error = None + record.enabled = True + record.created_at = datetime(2024, 1, 1, tzinfo=UTC) + record.updated_at = datetime(2024, 1, 1, tzinfo=UTC) + record.disabled_at = None + record.disabled_by = None + return record + + +def test_generate_summary_for_segment_passes_document_language(monkeypatch: pytest.MonkeyPatch) -> None: + usage = MagicMock() + usage.total_tokens = 10 + usage.prompt_tokens = 3 + usage.completion_tokens = 7 + + paragraph_module = SimpleNamespace( + 
ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("sum", usage))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + segment = _segment(has_document=True) + dataset = _dataset() + + content, got_usage = SummaryIndexService.generate_summary_for_segment(segment, dataset, {"a": 1}) + assert content == "sum" + assert got_usage is usage + + paragraph_module.ParagraphIndexProcessor.generate_summary.assert_called_once() + _, kwargs = paragraph_module.ParagraphIndexProcessor.generate_summary.call_args + assert kwargs["document_language"] == "en" + + +def test_generate_summary_for_segment_raises_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("", MagicMock()))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + with pytest.raises(ValueError, match="Generated summary is empty"): + SummaryIndexService.generate_summary_for_segment(_segment(), _dataset(), {"a": 1}) + + +def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytest.MonkeyPatch) -> None: + existing = _summary_record(summary_content="old", node_id="n1") + existing.enabled = False + existing.disabled_at = datetime(2024, 1, 1) + existing.disabled_by = "u" + + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = existing + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + segment = _segment() + dataset = _dataset() + + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating") + assert result is existing + assert existing.summary_content == "new" + assert existing.status == "generating" + assert existing.enabled is True + assert existing.disabled_at is None + assert existing.disabled_by is None + assert existing.error is None + session.add.assert_called_once_with(existing) + session.flush.assert_called_once() + + +def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating") + assert record.dataset_id == "dataset-1" + assert record.chunk_id == "seg-1" + assert record.summary_content == "new" + assert record.enabled is True + session.add.assert_called_once() + session.flush.assert_called_once() + + +def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + vector_cls = MagicMock() + monkeypatch.setattr(summary_module, "Vector", vector_cls) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + vector_cls.assert_not_called() + + +def test_vectorize_summary_raises_for_blank_content() -> None: + with pytest.raises(ValueError, match="Summary content is 
empty"): + SummaryIndexService.vectorize_summary(_summary_record(summary_content=" "), _segment(), _dataset()) + + +def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [5] + model_manager = MagicMock() + model_manager.get_model_instance.return_value = embedding_model + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + session = MagicMock(name="provided_session") + merged = _summary_record(summary_content="sum") + session.merge.return_value = merged + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + assert vector_instance.add_texts.call_count == 2 + summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] + session.flush.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "uuid-1" + assert summary.summary_index_node_hash == "hash-1" + assert summary.tokens == 5 + + +def test_vectorize_summary_without_session_creates_record_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id="old-node") + + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + # Force deletion branch to run and swallow delete failures. + vector_for_delete = MagicMock() + vector_for_delete.delete_by_ids.side_effect = RuntimeError("delete failed") + vector_for_add = MagicMock() + vector_for_add.add_texts.return_value = None + vector_cls = MagicMock(side_effect=[vector_for_delete, vector_for_add]) + monkeypatch.setattr(summary_module, "Vector", vector_cls) + + model_manager = MagicMock() + model_manager.get_model_instance.side_effect = RuntimeError("no model") + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + # New session used after vectorization succeeds (record not found by id nor chunk_id). + session = MagicMock(name="session") + q1 = MagicMock() + q1.filter_by.return_value = q1 + q1.first.side_effect = [None, None] + session.query.return_value = q1 + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # One context for success path, no error handler session. 
+ create_session_mock.assert_called() + session.add.assert_called() + session.commit.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "old-node" # reused + + +def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + # error_session should find record and commit status update + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(return_value=_SessionContext(error_session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + assert summary.status == "error" + assert "Vectorization failed" in (summary.error or "") + error_session.commit.assert_called_once() + + +def test_batch_create_summary_records_no_segments_noop(monkeypatch: pytest.MonkeyPatch) -> None: + create_session_mock = MagicMock() + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + SummaryIndexService.batch_create_summary_records([], _dataset()) + create_session_mock.assert_not_called() + + +def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + s1 = _segment() + s2 = _segment() + s2.id = "seg-2" + s2.document_id = "doc-2" + + existing = _summary_record() + existing.chunk_id = "seg-2" + existing.enabled = False + + session = MagicMock() + query = MagicMock() + query.filter.return_value = query + query.all.return_value = [existing] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started") + session.commit.assert_called_once() + assert existing.enabled is True + + +def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + assert record.status == "error" + assert record.error == "err" + session.commit.assert_called_once() + + +def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") 
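+    # Start from a record whose summary_content is still empty.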
+ + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert out is record + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert record.status == "error" + # Outer exception handler overwrites the error with the raw exception message. + assert record.error == "boom" + + +def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + vector_instance = MagicMock() + vector_instance.add_texts.return_value = None + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + existing.id = "other-id" + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, existing] # miss by id, hit by chunk_id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_id == "uuid-1" + + +def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", 
MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = existing # hit by id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_hash == "hash-1" + + +def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + class _BadContext: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + error_session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="Session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # miss by id and chunk_id + session.query.return_value = q + + error_session = MagicMock() + eq = MagicMock() + eq.filter_by.return_value = eq + eq.first.return_value = summary + error_session.query.return_value = eq + + create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + # Force the created record to be None so the "should not be None" guard triggers. 
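+    # Patching the model class so that calling its constructor returns None.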
+ monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) + + with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(add_texts=MagicMock(side_effect=RuntimeError("boom")))), + ) + + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # not found by id, not found by chunk_id + error_session.query.return_value = q + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(error_session))), + ) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # No record -> no commit in error session. + error_session.commit.assert_not_called() + + +def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + logger_mock.warning.assert_called_once() + + +def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + usage = MagicMock(total_tokens=4, prompt_tokens=1, completion_tokens=3) + monkeypatch.setattr(SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", usage))) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert result.status in {"generating", "completed"} + logger_mock.info.assert_called() + + +def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset(indexing_technique="economy") + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + assert 
SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + dataset = _dataset() + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] + + document.doc_form = "qa_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + seg1 = _segment() + seg2 = _segment() + seg2.id = "seg-2" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg1, seg2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr( + SummaryIndexService, + "generate_and_vectorize_summary", + MagicMock(side_effect=[MagicMock(), RuntimeError("boom")]), + ) + update_err_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "update_summary_record_error", update_err_mock) + + records = SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) + assert len(records) == 1 + update_err_mock.assert_called_once() + + +def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + seg = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr(SummaryIndexService, "generate_and_vectorize_summary", MagicMock(return_value=MagicMock())) + + SummaryIndexService.generate_summaries_for_document( + dataset, + document, + {"enable": True}, + segment_ids=[seg.id], + only_parent_chunks=True, + ) + query.filter.assert_called() + + +def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="s", node_id="n1") + summary2 = _summary_record(summary_content="s", 
node_id=None) + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary1, summary2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(delete_by_ids=MagicMock(side_effect=RuntimeError("boom")))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + + SummaryIndexService.disable_summaries_for_segments(dataset, segment_ids=["seg-1"], disabled_by="u") + assert summary1.enabled is False + assert summary1.disabled_by == "u" + session.commit.assert_called_once() + + +def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + SummaryIndexService.disable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_non_high_quality() -> None: + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + + +def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + summary.enabled = False + + segment = _segment() + segment.id = summary.chunk_id + segment.enabled = True + segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.return_value = segment + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + vec_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vec_mock) + + SummaryIndexService.enable_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vec_mock.assert_called_once() + assert summary.enabled is True + session.commit.assert_called_once() + + +def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + 
SummaryIndexService.enable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vectorize_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="sum", node_id="n1") + summary1.enabled = False + summary2 = _summary_record(summary_content="", node_id="n2") + summary2.enabled = False + summary3 = _summary_record(summary_content="sum3", node_id="n3") + summary3.enabled = False + + bad_segment = _segment() + bad_segment.enabled = False + bad_segment.status = "completed" + + good_segment = _segment() + good_segment.enabled = True + good_segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary1, summary2, summary3] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.side_effect = [bad_segment, good_segment, good_segment] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + SummaryIndexService.enable_summaries_for_segments(dataset) + logger_mock.exception.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary] + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(summary) + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.delete_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_update_summary_for_segment_skip_conditions() -> None: + assert ( + SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None + ) + seg = _segment(has_document=True) + seg.document.doc_form = 
"qa_model" + assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None + + +def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(record) + session.commit.assert_called_once() + + +def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, "") is None + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + + +def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vectorize_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new summary") + assert out is record + vectorize_mock.assert_called_once() + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def 
test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_existing_vectorize_failure_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is record + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is created + session.refresh.assert_called() + session.commit.assert_called() + + +def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + session.flush.side_effect = RuntimeError("flush boom") + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + with pytest.raises(RuntimeError, match="flush boom"): + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert record.status == 
"error" + assert record.error == "flush boom" + session.commit.assert_called() + + +def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: + record = _summary_record(summary_content="sum", node_id="n1") + session = MagicMock() + + q1 = MagicMock() + q1.where.return_value = q1 + q1.first.return_value = record + + q2 = MagicMock() + q2.filter.return_value = q2 + q2.all.return_value = [record] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + # first call used by get_segment_summary, second by get_document_summaries + if not hasattr(query_side_effect, "_called"): + query_side_effect._called = True # type: ignore[attr-defined] + return q1 + return q2 + return MagicMock() + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.get_segment_summary("seg-1", "dataset-1") is record + assert SummaryIndexService.get_document_summaries("doc-1", "dataset-1", segment_ids=["seg-1"]) == [record] + + +def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: + record1 = _summary_record() + record1.chunk_id = "seg-1" + record2 = _summary_record() + record2.chunk_id = "seg-2" + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [record1, record2] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + out = SummaryIndexService.get_segments_summaries(["seg-1", "seg-2"], "dataset-1") + assert set(out.keys()) == {"seg-1", "seg-2"} + + +def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") is None + + +def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.MonkeyPatch) -> None: + assert SummaryIndexService.get_documents_summary_index_status([], "dataset-1", "tenant-1") == {} + + +def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") + assert result["doc-1"] is None + + +def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + 
session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + + vectorize_mock = MagicMock(side_effect=RuntimeError("boom")) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_get_segments_summaries_empty_list() -> None: + assert SummaryIndexService.get_segments_summaries([], "dataset-1") == {} + + +def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: + seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") + session = MagicMock() + query = MagicMock() + query.where.return_value = query + query.all.return_value = [SimpleNamespace(id="seg-1")] + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" + + # Multiple docs + query2 = MagicMock() + query2.where.return_value = query2 + query2.all.return_value = [seg_row] + session2 = MagicMock() + session2.query.return_value = query2 + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session2))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") + assert result["doc-1"] == "SUMMARIZING" + assert result["doc-2"] is None + + +def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pytest.MonkeyPatch) -> None: + segment1 = _segment() + segment1.id = "seg-1" + segment1.position = 1 + segment2 = _segment() + segment2.id = "seg-2" + segment2.position = 2 + + summary1 = _summary_record(summary_content="x" * 150, node_id="n1") + summary1.chunk_id = "seg-1" + summary1.status = "completed" + summary1.error = None + summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) + summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) + + segment_service = SimpleNamespace(get_segments_by_document_and_dataset=MagicMock(return_value=[segment1, segment2])) + monkeypatch.setitem(sys.modules, "services.dataset_service", SimpleNamespace(SegmentService=segment_service)) + + monkeypatch.setattr(SummaryIndexService, "get_document_summaries", MagicMock(return_value=[summary1])) + + detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1") + assert detail["total_segments"] == 2 + assert detail["summary_status"]["completed"] == 1 + assert detail["summary_status"]["not_started"] == 1 + assert detail["summaries"][0]["summary_preview"].endswith("...") + assert detail["summaries"][1]["status"] == "not_started" diff --git 
a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py new file mode 100644 index 0000000000..7b0103a2a1 --- /dev/null +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -0,0 +1,704 @@ +"""Unit tests for `api/services/vector_service.py`.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import services.vector_service as vector_service_module +from services.vector_service import VectorService + + +@dataclass(frozen=True) +class _UploadFileStub: + id: str + name: str + + +@dataclass(frozen=True) +class _ChildDocStub: + page_content: str + metadata: dict[str, Any] + + +@dataclass +class _ParentDocStub: + children: list[_ChildDocStub] + + +def _make_dataset( + *, + indexing_technique: str = "high_quality", + doc_form: str = "text_model", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + is_multimodal: bool = False, + embedding_model_provider: str | None = "openai", + embedding_model: str = "text-embedding", +) -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + dataset.is_multimodal = is_multimodal + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + return dataset + + +def _make_segment( + *, + segment_id: str = "seg-1", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", + content: str = "hello", + index_node_id: str = "node-1", + index_node_hash: str = "hash-1", + attachments: list[dict[str, str]] | None = None, +) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = segment_id + segment.tenant_id = tenant_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.index_node_id = index_node_id + segment.index_node_hash = index_node_hash + segment.attachments = attachments or [] + return segment + + +def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: + session = MagicMock(name="session") + + binding_query = MagicMock(name="binding_query") + binding_query.where.return_value = binding_query + binding_query.delete.return_value = 1 + + upload_query = MagicMock(name="upload_query") + upload_query.where.return_value = upload_query + upload_query.all.return_value = upload_files or [] + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.SegmentAttachmentBinding: + return binding_query + if model is vector_service_module.UploadFile: + return upload_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=False) + segment = _make_segment() + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + 
index_processor.load.assert_called_once() + args, kwargs = index_processor.load.call_args + assert args[0] == dataset + assert len(args[1]) == 1 + assert args[2] is None + assert kwargs["with_keywords"] is True + assert kwargs["keywords_list"] == [["k1"]] + + +def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=True) + segment = _make_segment( + attachments=[ + {"id": "img-1", "name": "a.png"}, + {"id": "img-2", "name": "b.png"}, + ] + ) + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + assert index_processor.load.call_count == 2 + first_args, first_kwargs = index_processor.load.call_args_list[0] + assert first_args[0] == dataset + assert len(first_args[1]) == 1 + assert first_kwargs["with_keywords"] is True + + second_args, second_kwargs = index_processor.load.call_args_list[1] + assert second_args[0] == dataset + assert second_args[1] == [] + assert len(second_args[2]) == 2 + assert second_kwargs["with_keywords"] is False + + +def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector(None, [], dataset, "text_model") + index_processor.load.assert_not_called() + + +def _mock_parent_child_queries( + *, + dataset_document: object | None, + processing_rule: object | None, +) -> MagicMock: + session = MagicMock(name="session") + + doc_query = MagicMock(name="doc_query") + doc_query.filter_by.return_value = doc_query + doc_query.first.return_value = dataset_document + + rule_query = MagicMock(name="rule_query") + rule_query.where.return_value = rule_query + rule_query.first.return_value = processing_rule + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.DatasetDocument: + return doc_query + if model is vector_service_module.DatasetProcessRule: + return rule_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider="openai", + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock(name="dataset_document") + dataset_document.id = segment.document_id + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock(name="processing_rule") + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + 
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock(name="embedding_model_instance") + model_manager_instance = MagicMock(name="model_manager_instance") + model_manager_instance.get_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once_with( + segment, dataset_document, dataset, embedding_model_instance, processing_rule, False + ) + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider=None, + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock() + model_manager_instance = MagicMock() + model_manager_instance.get_default_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_default_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once() + + +def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + 
factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + logger_mock.warning.assert_called_once() + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None), + ) + + with pytest.raises(ValueError, match="No processing rule found"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + indexing_technique="economy", + ) + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + with pytest.raises(ValueError, match="not high quality"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + segment = _make_segment() + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_segment_vector(["k"], segment, dataset) + + vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + vector_instance.add_texts.assert_called_once() + add_args, add_kwargs = vector_instance.add_texts.call_args + assert len(add_args[0]) == 1 + assert add_kwargs["duplicate_check"] is True + + +def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(["a", "b"], segment, dataset) + + keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + keyword_instance.add_texts.assert_called_once() + args, kwargs = keyword_instance.add_texts.call_args + assert len(args[0]) == 1 + assert kwargs["keywords_list"] == [["a", "b"]] + + +def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(None, 
segment, dataset) + keyword_instance.add_texts.assert_called_once() + _, kwargs = keyword_instance.add_texts.call_args + assert "keywords_list" not in kwargs + + +def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + segment = _make_segment(segment_id="seg-1") + + dataset_document = MagicMock() + dataset_document.id = segment.document_id + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"}) + child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"}) + transformed = [_ParentDocStub(children=[child1, child2])] + + index_processor = MagicMock() + index_processor.transform.return_value = transformed + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor) + + db_mock = MagicMock() + db_mock.session.add = MagicMock() + db_mock.session.commit = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=True, + ) + + index_processor.clean.assert_called_once() + _, transform_kwargs = index_processor.transform.call_args + assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC + index_processor.load.assert_called_once() + assert db_mock.session.add.call_count == 2 + db_mock.session.commit.assert_called_once() + + +def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model") + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + index_processor = MagicMock() + index_processor.transform.return_value = [_ParentDocStub(children=[])] + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + db_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=False, + ) + + index_processor.load.assert_not_called() + db_mock.session.add.assert_not_called() + db_mock.session.commit.assert_called_once() + + +def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + 
child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_instance.add_texts.assert_called_once() + + +def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_cls.assert_not_called() + + +def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + + new_chunk = MagicMock() + new_chunk.content = "n" + new_chunk.index_node_id = "nid" + new_chunk.index_node_hash = "nh" + new_chunk.document_id = "d" + new_chunk.dataset_id = "ds" + + upd_chunk = MagicMock() + upd_chunk.content = "u" + upd_chunk.index_node_id = "uid" + upd_chunk.index_node_hash = "uh" + upd_chunk.document_id = "d" + upd_chunk.dataset_id = "ds" + + del_chunk = MagicMock() + del_chunk.index_node_id = "did" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset) + + vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"]) + vector_instance.add_texts.assert_called_once() + docs = vector_instance.add_texts.call_args.args[0] + assert len(docs) == 2 + + +def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + VectorService.update_child_chunk_vector([], [], [], dataset) + vector_cls.assert_not_called() + + +def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + child_chunk = MagicMock() + child_chunk.index_node_id = "cid" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.delete_child_chunk_vector(child_chunk, dataset) + vector_instance.delete_by_ids.assert_called_once_with(["cid"]) + + +# --------------------------------------------------------------------------- +# update_multimodel_vector (missing coverage in previous suites) +# --------------------------------------------------------------------------- + + +def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset) + vector_cls.assert_not_called() + 
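# Economy-indexed datasets are expected to short-circuit before any vector-store or DB work. +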
db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) + + vector_instance = MagicMock(name="vector_instance") + vector_cls = MagicMock(return_value=vector_instance) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset) + + vector_cls.assert_called_once_with(dataset=dataset) + vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) + db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset) + + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", 
"missing"], dataset=dataset) + + logger_mock.warning.assert_called_once() + db_mock.session.add_all.assert_called_once() + bindings = db_mock.session.add_all.call_args.args[0] + assert len(bindings) == 1 + assert bindings[0]["attachment_id"] == "file-1" + + vector_instance.add_texts.assert_called_once() + documents = vector_instance.add_texts.call_args.args[0] + assert len(documents) == 1 + assert documents[0].page_content == "img.png" + assert documents[0].metadata["doc_id"] == "file-1" + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + vector_instance.delete_by_ids.assert_not_called() + vector_instance.add_texts.assert_not_called() + db_mock.session.add_all.assert_called_once() + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + db_mock.session.commit.side_effect = RuntimeError("boom") + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + with pytest.raises(RuntimeError, match="boom"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + logger_mock.exception.assert_called_once() + db_mock.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py new file mode 100644 index 0000000000..e2775ce90c --- /dev/null +++ b/api/tests/unit_tests/services/test_website_service.py @@ -0,0 +1,718 @@ +"""Unit tests for services.website_service. + +Focuses on provider dispatching, argument validation, and provider-specific branches +without making any real network/storage/redis calls. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import services.website_service as website_service_module +from services.website_service import ( + CrawlOptions, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +@dataclass(frozen=True) +class _DummyHttpxResponse: + payload: dict[str, Any] + + def json(self) -> dict[str, Any]: + return self.payload + + +@pytest.fixture(autouse=True) +def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module, + "current_user", + type("User", (), {"current_tenant_id": "tenant-1"})(), + ) + + +def test_crawl_options_include_exclude_paths() -> None: + options = CrawlOptions(includes="a,b", excludes="x,y") + assert options.get_include_paths() == ["a", "b"] + assert options.get_exclude_paths() == ["x", "y"] + + empty = CrawlOptions(includes=None, excludes=None) + assert empty.get_include_paths() == [] + assert empty.get_exclude_paths() == [] + + +def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None: + args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a,b", + "excludes": "x", + "prompt": "hi", + "max_depth": 3, + "use_sitemap": False, + }, + } + + api_req = WebsiteCrawlApiRequest.from_args(args) + crawl_req = api_req.to_crawl_request() + + assert crawl_req.provider == "firecrawl" + assert crawl_req.url == "https://example.com" + assert crawl_req.options.limit == 2 + assert crawl_req.options.crawl_sub_pages is True + assert crawl_req.options.only_main_content is True + assert crawl_req.options.get_include_paths() == ["a", "b"] + assert crawl_req.options.get_exclude_paths() == ["x"] + assert crawl_req.options.prompt == "hi" + assert crawl_req.options.max_depth == 3 + assert crawl_req.options.use_sitemap is False + + +@pytest.mark.parametrize( + ("args", "missing_msg"), + [ + ({}, "Provider is required"), + ({"provider": "firecrawl"}, "URL is required"), + ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"), + ], +) +def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None: + with pytest.raises(ValueError, match=missing_msg): + WebsiteCrawlApiRequest.from_args(args) + + +def test_website_crawl_status_api_request_from_args_requires_fields() -> None: + with pytest.raises(ValueError, match="Provider is required"): + WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1") + + with pytest.raises(ValueError, match="Job ID is required"): + WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="") + + req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1") + assert req.provider == "firecrawl" + assert req.job_id == "job-1" + + +def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl") + assert api_key == "k" + assert 
config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider="firecrawl", + plugin_id="langgenius/firecrawl_datasource", + ) + + +@pytest.mark.parametrize( + ("provider", "plugin_id"), + [ + ("watercrawl", "watercrawl/watercrawl_datasource"), + ("jinareader", "langgenius/jina_datasource"), + ], +) +def test_get_credentials_and_config_selects_plugin_id_and_key_api_key( + monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str +) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider) + assert api_key == "enc-key" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider=provider, + plugin_id=plugin_id, + ) + + +def test_get_credentials_and_config_rejects_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", "unknown") + + +def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None: + class FlakyProvider: + def __init__(self) -> None: + self._eq_calls = 0 + + def __hash__(self) -> int: + return 1 + + def __eq__(self, other: object) -> bool: + if other == "firecrawl": + self._eq_calls += 1 + return self._eq_calls == 1 + return False + + def __repr__(self) -> str: + return "FlakyProvider()" + + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider()) # type: ignore[arg-type] + + +def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock()) + with pytest.raises(ValueError, match="API key not found in configuration"): + WebsiteService._get_decrypted_api_key("tenant-1", {}) + + +def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None: + decrypt_mock = MagicMock(return_value="plain") + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", decrypt_mock) + + assert WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"}) == "plain" + decrypt_mock.assert_called_once_with(tenant_id="tenant-1", token="enc") + + +def test_document_create_args_validate_wraps_error_message() -> None: + with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"): + WebsiteService.document_create_args_validate({}) + + +def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1}) + crawl_request = api_request.to_crawl_request() + + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"}) + 
monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock) + + result = WebsiteService.crawl_url(api_request) + + assert result == {"status": "active", "job_id": "j1"} + firecrawl_mock.assert_called_once() + assert firecrawl_mock.call_args.kwargs["request"] == crawl_request + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("watercrawl", "_crawl_with_watercrawl"), + ("jinareader", "_crawl_with_jinareader"), + ], +) +def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + impl_mock = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + assert WebsiteService.crawl_url(api_request) == {"status": "active"} + impl_mock.assert_called_once() + + +def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.crawl_url(api_request) + + +def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-1" + firecrawl_cls = MagicMock(return_value=firecrawl_instance) + monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls) + + redis_mock = MagicMock() + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + fixed_now = datetime(2024, 1, 1, tzinfo=UTC) + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = fixed_now + + req = WebsiteCrawlApiRequest( + provider="firecrawl", url="https://example.com", options={"limit": 5} + ).to_crawl_request() + req.options.crawl_sub_pages = False + req.options.only_main_content = True + + result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"}) + + assert result == {"status": "active", "job_id": "job-1"} + + firecrawl_cls.assert_called_once_with(api_key="k", base_url="b") + firecrawl_instance.crawl_url.assert_called_once() + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["limit"] == 1 + assert params["includePaths"] == [] + assert params["excludePaths"] == [] + assert params["scrapeOptions"] == {"onlyMainContent": True} + + redis_mock.setex.assert_called_once() + key, ttl, value = redis_mock.setex.call_args.args + assert key == "website_crawl_job-1" + assert ttl == 3600 + assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6) + + +def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-2" + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + monkeypatch.setattr(website_service_module, "redis_client", MagicMock()) + + req = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + 
"crawl_sub_pages": True, + "limit": 3, + "only_main_content": False, + "includes": "a,b", + "excludes": "x", + "prompt": "use this", + }, + ).to_crawl_request() + + WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None}) + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["includePaths"] == ["a", "b"] + assert params["excludePaths"] == ["x"] + assert params["limit"] == 3 + assert params["scrapeOptions"] == {"onlyMainContent": False} + assert params["prompt"] == "use this" + + +def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"} + provider_cls = MagicMock(return_value=provider_instance) + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls) + + req = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ).to_crawl_request() + + result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"}) + assert result == {"status": "active", "job_id": "w1"} + + provider_cls.assert_called_once_with(api_key="k", base_url="b") + provider_instance.crawl_url.assert_called_once_with( + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ) + + +def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) + monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "data": {"title": "t"}} + get_mock.assert_called_once() + + +def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + with pytest.raises(ValueError, match="Failed to crawl:"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "job_id": "t1"} + post_mock.assert_called_once() + + +def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + 
monkeypatch.setattr( + website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + ) + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_status = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_status) + + result = WebsiteService.get_crawl_status("job-1", "firecrawl") + assert result == {"status": "active"} + firecrawl_status.assert_called_once_with("job-1", "k", {"base_url": "b"}) + + watercrawl_status = MagicMock(return_value={"status": "active", "job_id": "w"}) + monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_status) + assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"} + watercrawl_status.assert_called_once_with("job-2", "k", {"base_url": "b"}) + + jinareader_status = MagicMock(return_value={"status": "active", "job_id": "j"}) + monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_status) + assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"} + jinareader_status.assert_called_once_with("job-3", "k") + + +def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_status_typed(WebsiteCrawlStatusApiRequest(provider="bad", job_id="j")) + + +def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = b"100.0" + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = datetime.fromtimestamp(105.0, tz=UTC) + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"}) + + assert result["status"] == "completed" + assert result["time_consuming"] == "5.00" + redis_mock.delete.assert_called_once_with("website_crawl_job-1") + + +def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = None + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", 
config={"base_url": None}) + assert result["status"] == "completed" + assert "time_consuming" not in result + redis_mock.delete.assert_not_called() + + +def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_status.return_value = {"status": "active", "job_id": "w1"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + assert WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"}) == { + "status": "active", + "job_id": "w1", + } + provider_instance.get_crawl_status.assert_called_once_with("job-1") + + +def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock( + return_value=_DummyHttpxResponse( + { + "data": { + "status": "active", + "urls": ["a", "b"], + "processed": {"a": {}}, + "failed": {"b": {}}, + "duration": 3000, + } + } + ) + ) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "active" + assert result["total"] == 2 + assert result["current"] == 2 + assert result["time_consuming"] == 3.0 + assert result["data"] == [] + post_mock.assert_called_once() + + +def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = { + "data": { + "status": "completed", + "urls": ["u1"], + "processed": {"u1": {}}, + "failed": {}, + "duration": 1000, + } + } + processed_payload = { + "data": { + "processed": { + "u1": { + "data": { + "title": "t", + "url": "u1", + "description": "d", + "content": "md", + } + } + } + } + } + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "completed" + assert result["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}] + assert post_mock.call_count == 2 + + +def test_get_crawl_url_data_dispatches_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1") + + +def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("firecrawl", "_get_firecrawl_url_data"), + ("watercrawl", "_get_watercrawl_url_data"), + ("jinareader", "_get_jinareader_url_data"), + ], +) +def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + impl_mock = MagicMock(return_value={"ok": True}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + result = WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1") + assert result == {"ok": True} + impl_mock.assert_called_once() + + +def 
test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None: + stored_list = [{"source_url": "https://example.com", "title": "t"}] + stored = json.dumps(stored_list).encode("utf-8") + + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = stored + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock()) + + result = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) + assert result == {"source_url": "https://example.com", "title": "t"} + assert result is not stored_list[0] + + +def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = b"" + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None + + +def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "active"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + with pytest.raises(ValueError, match="Crawl job is not completed"): + WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None}) + + +def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None + + +def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_url_data.return_value = {"source_url": "u"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"}) + assert result == {"source_url": "u"} + provider_instance.get_crawl_url_data.assert_called_once_with("job-1", "u") + + +def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), + ) + assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"} + + +def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._get_jinareader_url_data("", "u", "k") + + 
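+# Minimal sketch of the two-step Reader flow the job-id tests below assume
+# (endpoint names are hypothetical; the tests only assert httpx.post call
+# counts and the payload shapes mocked here):
+#
+#   status = httpx.post(<status-endpoint>, ...).json()["data"]
+#   if status["status"] != "completed":
+#       raise ValueError("Crawl job is not completed")
+#   processed = httpx.post(<detail-endpoint>, ...).json()["data"]["processed"]
+#   return next((v["data"] for v in processed.values() if v["data"]["url"] == url), None)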
+def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} + assert post_mock.call_count == 2 + + +def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + with pytest.raises(ValueError, match="Crawl job is not completed"): + WebsiteService._get_jinareader_url_data("job-1", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None + + +def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + scrape_mock = MagicMock(return_value={"data": "x"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", scrape_mock) + assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"} + scrape_mock.assert_called_once() + + watercrawl_mock = MagicMock(return_value={"data": "y"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_mock) + assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"} + watercrawl_mock.assert_called_once() + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True) + + +def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + result = WebsiteService._scrape_with_firecrawl( + request=website_service_module.ScrapeRequest( + provider="firecrawl", + url="u", + tenant_id="tenant-1", + only_main_content=True, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + firecrawl_instance.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True}) + + +def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._scrape_with_watercrawl(
request=website_service_module.ScrapeRequest( + provider="watercrawl", + url="u", + tenant_id="tenant-1", + only_main_content=False, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + provider_instance.scrape_url.assert_called_once_with("u") diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 8820a1acc0..5ce0e6f140 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -1001,12 +1001,12 @@ class TestWorkflowService: Used by the UI to populate the node palette and provide sensible defaults when users add new nodes to their workflow. """ - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: + with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: # Mock node class with default config mock_node_class = MagicMock() mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [(NodeType.LLM, {"latest": mock_node_class})] + mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} with patch("services.workflow_service.LATEST_VERSION", "latest"): result = workflow_service.get_default_block_configs() @@ -1025,7 +1025,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1036,10 +1036,10 @@ class TestWorkflowService: mock_http_node_class.get_default_config.return_value = {"type": "http-request", "config": {}} mock_llm_node_class = MagicMock() mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [ - (NodeType.HTTP_REQUEST, {"latest": mock_http_node_class}), - (NodeType.LLM, {"latest": mock_llm_node_class}), - ] + mock_mapping.return_value = { + NodeType.HTTP_REQUEST: {"latest": mock_http_node_class}, + NodeType.LLM: {"latest": mock_llm_node_class}, + } result = workflow_service.get_default_block_configs() @@ -1060,7 +1060,7 @@ class TestWorkflowService: This includes default values for all required and optional parameters. 
""" with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), ): # Mock node class with default config @@ -1069,8 +1069,7 @@ class TestWorkflowService: mock_node_class.get_default_config.return_value = mock_config # Create a mock mapping that includes NodeType.LLM - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.LLM: {"latest": mock_node_class}} result = workflow_service.get_default_block_config(NodeType.LLM.value) @@ -1079,9 +1078,8 @@ class TestWorkflowService: def test_get_default_block_config_invalid_node_type(self, workflow_service): """Test get_default_block_config returns empty dict for invalid node type.""" - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: - # Mock mapping to not contain the node type - mock_mapping.__contains__.return_value = False + with patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping: + mock_mapping.return_value = {} # Use a valid NodeType but one that's not in the mapping result = workflow_service.get_default_block_config(NodeType.LLM.value) @@ -1100,7 +1098,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1110,8 +1108,7 @@ class TestWorkflowService: mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value) @@ -1132,15 +1129,14 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_workflow_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch("services.workflow_service.build_http_request_config") as mock_build_config, ): mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {NodeType.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config( NodeType.HTTP_REQUEST.value, @@ -1155,8 +1151,8 @@ class TestWorkflowService: def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service): with ( patch( - "services.workflow_service.NODE_TYPE_CLASSES_MAPPING", - {NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, + "services.workflow_service.get_workflow_node_type_classes_mapping", + return_value={NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, ), patch("services.workflow_service.LATEST_VERSION", "latest"), ): diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py 
b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py new file mode 100644 index 0000000000..439d203c58 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -0,0 +1,455 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +MODULE = "services.tools.builtin_tools_manage_service" + + +def _mock_session(mock_session_cls): + """Helper: set up a Session context manager mock and return the inner session.""" + session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + return session + + +class TestDeleteCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_and_returns_success(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + + result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") + + assert result == {"result": "success"} + session.query.return_value.filter_by.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +class TestListBuiltinToolProviderTools: + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [MagicMock(), MagicMock()] + mock_manager.get_builtin_provider.return_value = mock_controller + mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock() + + result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google") + + assert len(result) == 2 + + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_empty_tools(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [] + mock_manager.get_builtin_provider.return_value = mock_controller + + assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == [] + + +class TestGetBuiltinToolProviderInfo: + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.get_builtin_tool_provider_info("t", "no") + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = MagicMock() + entity = MagicMock() + mock_transform.builtin_provider_to_user_provider.return_value = entity + + result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google") + + assert result.original_credentials == {} + + +class TestListBuiltinProviderCredentialsSchema: + @patch(f"{MODULE}.ToolManager") + def test_returns_schema(self, mock_manager): + mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}] + + result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t") + + assert 
result == [{"f": "k"}] + + +class TestGetBuiltinToolProviderIcon: + @patch(f"{MODULE}.Path") + @patch(f"{MODULE}.ToolManager") + def test_returns_bytes_and_mime(self, mock_manager, mock_path): + mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml") + mock_path.return_value.read_bytes.return_value = b"" + + icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google") + + assert icon == b"" + assert mime == "image/svg+xml" + + +class TestIsOauthSystemClientExists: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock() + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_missing(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False + + +class TestIsOauthCustomClientEnabled: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_enabled(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True) + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False + + +class TestDeleteBuiltinToolProvider: + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock() + session.query.return_value.where.return_value.first.return_value = db_provider + mock_cache = MagicMock() + mock_enc.return_value = (MagicMock(), mock_cache) + + result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c") + + assert result == {"result": "success"} + session.delete.assert_called_once_with(db_provider) + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestSetDefaultProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="provider not found"): + 
BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + target = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = target + + result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + assert result == {"result": "success"} + assert target.is_default is True + session.commit.assert_called_once() + + +class TestUpdateBuiltinToolProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.CredentialType") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock(credential_type="api_key", credentials="{}") + session.query.return_value.where.return_value.first.return_value = db_provider + + mock_cred_instance = MagicMock() + mock_cred_instance.is_editable.return_value = True + mock_cred_instance.is_validate_allowed.return_value = False + mock_cred_type.of.return_value = mock_cred_instance + + mock_controller = MagicMock(need_credentials=True) + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "old"} + mock_encrypter.encrypt.return_value = {"key": "new"} + mock_cache = MagicMock() + mock_enc.return_value = (mock_encrypter, mock_cache) + + result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) + + assert result == {"result": "success"} + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestGetOauthClientSchema: + @patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={}) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True) + @patch(f"{MODULE}.dify_config") + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.ToolManager") + def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params): + mock_config.CONSOLE_API_URL = "https://api.example.com" + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google") + + assert "schema" in result + assert result["is_oauth_custom_client_enabled"] is True + assert "redirect_uri" in result + + +class TestGetOauthClient: + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_user_client_params_when_exists( + self, mock_db, mock_session_cls, mock_tm, 
mock_create_enc, mock_plugin + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"} + mock_create_enc.return_value = (mock_encrypter, MagicMock()) + + user_client = MagicMock(oauth_params='{"encrypted": "data"}') + session.query.return_value.filter_by.return_value.first.return_value = user_client + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"client_id": "id", "client_secret": "secret"} + + @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"}) + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_to_system_client( + self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_create_enc.return_value = (MagicMock(), MagicMock()) + + system_client = MagicMock(encrypted_oauth_params="enc") + session.query.return_value.filter_by.return_value.first.side_effect = [ + None, # user client + system_client, # system client + ] + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"sys_key": "sys_val"} + + +class TestSaveCustomOauthClientParams: + def test_returns_early_when_no_params(self): + result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p") + assert result == {"result": "success"} + + @patch(f"{MODULE}.ToolManager") + def test_raises_when_provider_not_found(self, mock_tm): + mock_tm.get_builtin_provider.return_value = None + + with pytest.raises((ValueError, Exception), match="not found|Provider"): + BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True) + + +class TestGetCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_empty_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p") + + assert result == {} + + +class TestGetBuiltinToolProviderCredentialInfo: + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[]) + @patch(f"{MODULE}.ToolManager") + def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth): + mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"] + + result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google") + + assert result.credentials == [] + assert result.supported_credential_types == ["api-key"] + assert result.is_oauth_custom_client_enabled is False + + +class TestGetBuiltinToolProviderCredentials: + @patch(f"{MODULE}.db") + def test_returns_empty_when_no_providers(self, mock_db): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = 
MagicMock(return_value=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert result == [] + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.db") + def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) + + provider = MagicMock(provider="google", is_default=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider] + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "decrypted"} + mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"} + mock_enc.return_value = (mock_encrypter, MagicMock()) + + credential_entity = MagicMock() + mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert len(result) == 1 + assert result[0] is credential_entity + assert provider.is_default is True + + +class TestGetBuiltinProvider: + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.return_value.provider_name = "google" + mock_prov_id.return_value.organization = "langgenius" + session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is None + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + db_provider = MagicMock(provider="google") + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "google" + m.organization = "langgenius" + m.to_string.return_value = "langgenius/google/google" + m.plugin_id = "langgenius/google" + return m + + mock_prov_id.side_effect = prov_id_side_effect + session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "custom-tool" + m.organization = "third-party" + m.to_string.return_value = "third-party/custom/custom-tool" + m.plugin_id = "third-party/custom" + return m + + mock_prov_id.side_effect = prov_id_side_effect + db_provider = MagicMock(provider="third-party/custom/custom-tool") +
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.side_effect = Exception("parse error") + fallback = MagicMock() + session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback + + result = BuiltinToolManageService.get_builtin_provider("old-provider", "t") + + assert result is fallback diff --git a/api/tests/unit_tests/services/tools/test_tool_labels_service.py b/api/tests/unit_tests/services/tools/test_tool_labels_service.py new file mode 100644 index 0000000000..6acdbb7901 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tool_labels_service.py @@ -0,0 +1,21 @@ +from services.tools.tool_labels_service import ToolLabelsService + + +def test_list_tool_labels_returns_default_labels(): + result = ToolLabelsService.list_tool_labels() + assert isinstance(result, list) + assert len(result) > 0 + + +def test_list_tool_labels_items_are_tool_labels(): + from core.tools.entities.tool_entities import ToolLabel + + result = ToolLabelsService.list_tool_labels() + for label in result: + assert isinstance(label, ToolLabel) + + +def test_list_tool_labels_matches_default_values(): + from core.tools.entities.values import default_tool_labels + + assert ToolLabelsService.list_tool_labels() is default_tool_labels diff --git a/api/tests/unit_tests/services/tools/test_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_tools_manage_service.py new file mode 100644 index 0000000000..73ac9a10c6 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tools_manage_service.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +from services.tools.tools_manage_service import ToolCommonService + + +class TestToolCommonService: + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform): + mock_provider1 = MagicMock() + mock_provider1.to_dict.return_value = {"name": "provider1"} + mock_provider2 = MagicMock() + mock_provider2.to_dict.return_value = {"name": "provider2"} + mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None) + assert mock_transform.repack_provider.call_count == 2 + assert result == [{"name": "provider1"}, {"name": "provider2"}] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin") + assert result == [] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def 
test_list_tool_providers_empty(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("u", "t") + + assert result == [] + mock_transform.repack_provider.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py new file mode 100644 index 0000000000..bbfc1cc294 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py @@ -0,0 +1,110 @@ +from unittest.mock import patch + +import pytest + +from services.workflow.queue_dispatcher import ( + BaseQueueDispatcher, + ProfessionalQueueDispatcher, + QueueDispatcherManager, + QueuePriority, + SandboxQueueDispatcher, + TeamQueueDispatcher, +) + + +class TestQueuePriority: + def test_priority_values(self): + assert QueuePriority.PROFESSIONAL == "workflow_professional" + assert QueuePriority.TEAM == "workflow_team" + assert QueuePriority.SANDBOX == "workflow_sandbox" + + +class TestDispatchers: + def test_professional_dispatcher(self): + d = ProfessionalQueueDispatcher() + assert d.get_queue_name() == QueuePriority.PROFESSIONAL + assert d.get_priority() == 100 + + def test_team_dispatcher(self): + d = TeamQueueDispatcher() + assert d.get_queue_name() == QueuePriority.TEAM + assert d.get_priority() == 50 + + def test_sandbox_dispatcher(self): + d = SandboxQueueDispatcher() + assert d.get_queue_name() == QueuePriority.SANDBOX + assert d.get_priority() == 10 + + def test_base_dispatcher_is_abstract(self): + with pytest.raises(TypeError): + BaseQueueDispatcher() + + +class TestQueueDispatcherManager: + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_professional_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, ProfessionalQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_team_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "team"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + 
@patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.side_effect = Exception("billing unavailable") + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_disabled_defaults_to_team(self, mock_config): + mock_config.BILLING_ENABLED = False + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) diff --git a/api/tests/unit_tests/services/workflow/test_scheduler.py b/api/tests/unit_tests/services/workflow/test_scheduler.py new file mode 100644 index 0000000000..90b6cb2d8b --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_scheduler.py @@ -0,0 +1,89 @@ +import pytest + +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand + + +class TestSchedulerCommand: + def test_enum_values(self): + assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached" + assert SchedulerCommand.NONE == "none" + + def test_enum_is_str(self): + for member in SchedulerCommand: + assert isinstance(member, str) + + +class TestCFSPlanScheduler: + def test_stores_plan(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + granularity=-1, + ) + + class ConcretePlanScheduler(CFSPlanScheduler): + def can_schedule(self): + return SchedulerCommand.NONE + + scheduler = ConcretePlanScheduler(plan) + + assert scheduler.plan is plan + assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop + assert scheduler.plan.granularity == -1 + + def test_cannot_instantiate_abstract(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=10, + ) + with pytest.raises(TypeError): + CFSPlanScheduler(plan) + + def test_concrete_subclass_can_schedule(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=5, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.NONE + + def test_concrete_subclass_resource_limit(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=-1, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED + + +class TestWorkflowScheduleCFSPlanEntity: + def 
test_strategy_values(self): + assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice" + assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop" + + def test_default_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + ) + assert plan.granularity == -1 + + def test_explicit_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=100, + ) + assert plan.granularity == 100 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index 5d6fa4c137..fcdd1c2368 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, @@ -22,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods): +def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,7 +32,7 @@ def _build_node_config(delivery_methods): user_actions=[], ).model_dump(mode="json") node_data["type"] = NodeType.HUMAN_INPUT.value - return {"id": "node-1", "data": node_data} + return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 83c1f8d9da..9ee8f88e71 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction from dify_graph.nodes.human_input.enums import FormInputType @@ -40,6 +41,23 @@ class TestWorkflowService: workflows.append(workflow) return workflows + @pytest.fixture + def dummy_session_cls(self): + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + return DummySession + def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): mock_app.workflow_id = None mock_session = MagicMock() @@ -169,7 +187,10 @@ class TestWorkflowService: mock_session.scalars.assert_called_once() def test_submit_human_input_form_preview_uses_rendered_content( - self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + self, + workflow_service: WorkflowService, + monkeypatch: pytest.MonkeyPatch, + dummy_session_cls, ) -> None: service = workflow_service node_data = HumanInputNodeData( @@ -187,25 +208,15 @@ class TestWorkflowService: 
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + node_config = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + ) + workflow.get_node_config_by_id.return_value = node_config workflow.get_enclosing_node_type_and_id.return_value = None service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] saved_outputs: dict[str, object] = {} - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - class DummySaver: def __init__(self, *args, **kwargs): pass @@ -213,7 +224,7 @@ class TestWorkflowService: def save(self, outputs, process_data): saved_outputs.update(outputs) - monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) @@ -232,7 +243,7 @@ class TestWorkflowService: service._build_human_input_variable_pool.assert_called_once_with( app_model=app_model, workflow=workflow, - node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, + node_config=node_config, manual_inputs={"#node-0.result#": "LLM output"}, ) @@ -267,12 +278,13 @@ class TestWorkflowService: service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + ) service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: service.submit_human_input_form_preview( app_model=app_model, @@ -284,3 +296,119 @@ class TestWorkflowService: ) assert "Missing required inputs" in str(exc_info.value) + + def test_run_draft_workflow_node_successful_behavior( + self, workflow_service, mock_app, monkeypatch, dummy_session_cls + ): + """Behavior: When a basic workflow node runs, it correctly sets up context, + executes the node, and saves outputs.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.id = "wf-1" + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + + # Mock node config + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.LLM.value}} + ) + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + # Mock module-level service classes + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + + # Mock workflow entry execution + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-1" + mock_node_exec.process_data =
{} + mock_run = MagicMock() + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) + + # Mock execution handling + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + # Mock repository + mock_repo = MagicMock() + mock_repo.get_execution_by_id.return_value = mock_node_exec + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Set up node execution service repo mock to return our exec node + mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} + mock_node_exec.node_id = "node-1" + mock_node_exec.node_type = "llm" + + # Mock draft variable saver + mock_saver = MagicMock() + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) + + # Mock DB + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act + result = service.run_draft_workflow_node( + app_model=mock_app, + draft_workflow=mock_workflow, + node_id="node-1", + user_inputs={"input_val": "test"}, + account=account, + ) + + # Assert + assert result == mock_node_exec + service._handle_single_step_result.assert_called_once() + mock_repo.save.assert_called_once_with(mock_node_exec) + mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) + + def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): + """Behavior: If retrieving the saved execution fails, an appropriate error bubbles up.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.LLM.value}} + ) + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) + + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-invalid" + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + mock_repo = MagicMock() + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Simulate failure to retrieve the saved execution + mock_repo.get_execution_by_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act & Assert + with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): + service.run_draft_workflow_node( + app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account + ) diff --git
a/api/tests/unit_tests/tools/__init__.py b/api/tests/unit_tests/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/uv.lock b/api/uv.lock index c989d18d56..555a980d97 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1205,29 +1205,41 @@ wheels = [ [[package]] name = "coverage" -version = "7.2.7" +version = "7.13.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/45/8b/421f30467e69ac0e414214856798d4bc32da1336df745e49e49ae5c1e2a8/coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59", size = 762575, upload-time = "2023-05-29T20:08:50.273Z" } +sdist = { url = "https://files.pythonhosted.org/packages/24/56/95b7e30fa389756cb56630faa728da46a27b8c6eb46f9d557c68fff12b65/coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91", size = 827239, upload-time = "2026-02-09T12:59:03.86Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/fa/529f55c9a1029c840bcc9109d5a15ff00478b7ff550a1ae361f8745f8ad5/coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f", size = 200895, upload-time = "2023-05-29T20:07:21.963Z" }, - { url = "https://files.pythonhosted.org/packages/67/d7/cd8fe689b5743fffac516597a1222834c42b80686b99f5b44ef43ccc2a43/coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe", size = 201120, upload-time = "2023-05-29T20:07:23.765Z" }, - { url = "https://files.pythonhosted.org/packages/8c/95/16eed713202406ca0a37f8ac259bbf144c9d24f9b8097a8e6ead61da2dbb/coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3", size = 233178, upload-time = "2023-05-29T20:07:25.281Z" }, - { url = "https://files.pythonhosted.org/packages/c1/49/4d487e2ad5d54ed82ac1101e467e8994c09d6123c91b2a962145f3d262c2/coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f", size = 230754, upload-time = "2023-05-29T20:07:27.044Z" }, - { url = "https://files.pythonhosted.org/packages/a7/cd/3ce94ad9d407a052dc2a74fbeb1c7947f442155b28264eb467ee78dea812/coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb", size = 232558, upload-time = "2023-05-29T20:07:28.743Z" }, - { url = "https://files.pythonhosted.org/packages/8f/a8/12cc7b261f3082cc299ab61f677f7e48d93e35ca5c3c2f7241ed5525ccea/coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833", size = 241509, upload-time = "2023-05-29T20:07:30.434Z" }, - { url = "https://files.pythonhosted.org/packages/04/fa/43b55101f75a5e9115259e8be70ff9279921cb6b17f04c34a5702ff9b1f7/coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97", size = 239924, upload-time = "2023-05-29T20:07:32.065Z" }, - { url = "https://files.pythonhosted.org/packages/68/5f/d2bd0f02aa3c3e0311986e625ccf97fdc511b52f4f1a063e4f37b624772f/coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a", size = 240977, upload-time = "2023-05-29T20:07:34.184Z" }, - { url = "https://files.pythonhosted.org/packages/ba/92/69c0722882643df4257ecc5437b83f4c17ba9e67f15dc6b77bad89b6982e/coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a", size = 203168, upload-time = "2023-05-29T20:07:35.869Z" }, - { url = "https://files.pythonhosted.org/packages/b1/96/c12ed0dfd4ec587f3739f53eb677b9007853fd486ccb0e7d5512a27bab2e/coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562", size = 204185, upload-time = "2023-05-29T20:07:37.39Z" }, - { url = "https://files.pythonhosted.org/packages/ff/d5/52fa1891d1802ab2e1b346d37d349cb41cdd4fd03f724ebbf94e80577687/coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4", size = 201020, upload-time = "2023-05-29T20:07:38.724Z" }, - { url = "https://files.pythonhosted.org/packages/24/df/6765898d54ea20e3197a26d26bb65b084deefadd77ce7de946b9c96dfdc5/coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4", size = 233994, upload-time = "2023-05-29T20:07:40.274Z" }, - { url = "https://files.pythonhosted.org/packages/15/81/b108a60bc758b448c151e5abceed027ed77a9523ecbc6b8a390938301841/coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01", size = 231358, upload-time = "2023-05-29T20:07:41.998Z" }, - { url = "https://files.pythonhosted.org/packages/61/90/c76b9462f39897ebd8714faf21bc985b65c4e1ea6dff428ea9dc711ed0dd/coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6", size = 233316, upload-time = "2023-05-29T20:07:43.539Z" }, - { url = "https://files.pythonhosted.org/packages/04/d6/8cba3bf346e8b1a4fb3f084df7d8cea25a6b6c56aaca1f2e53829be17e9e/coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d", size = 240159, upload-time = "2023-05-29T20:07:44.982Z" }, - { url = "https://files.pythonhosted.org/packages/6e/ea/4a252dc77ca0605b23d477729d139915e753ee89e4c9507630e12ad64a80/coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de", size = 238127, upload-time = "2023-05-29T20:07:46.522Z" }, - { url = "https://files.pythonhosted.org/packages/9f/5c/d9760ac497c41f9c4841f5972d0edf05d50cad7814e86ee7d133ec4a0ac8/coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d", size = 239833, upload-time = "2023-05-29T20:07:47.992Z" }, - { url = "https://files.pythonhosted.org/packages/69/8c/26a95b08059db1cbb01e4b0e6d40f2e9debb628c6ca86b78f625ceaf9bab/coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511", size = 203463, upload-time = "2023-05-29T20:07:49.939Z" }, - { url = "https://files.pythonhosted.org/packages/b7/00/14b00a0748e9eda26e97be07a63cc911108844004687321ddcc213be956c/coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = 
"sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3", size = 204347, upload-time = "2023-05-29T20:07:51.909Z" }, + { url = "https://files.pythonhosted.org/packages/b4/ad/b59e5b451cf7172b8d1043dc0fa718f23aab379bc1521ee13d4bd9bfa960/coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053", size = 219278, upload-time = "2026-02-09T12:56:31.673Z" }, + { url = "https://files.pythonhosted.org/packages/f1/17/0cb7ca3de72e5f4ef2ec2fa0089beafbcaaaead1844e8b8a63d35173d77d/coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11", size = 219783, upload-time = "2026-02-09T12:56:33.104Z" }, + { url = "https://files.pythonhosted.org/packages/ab/63/325d8e5b11e0eaf6d0f6a44fad444ae58820929a9b0de943fa377fe73e85/coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa", size = 250200, upload-time = "2026-02-09T12:56:34.474Z" }, + { url = "https://files.pythonhosted.org/packages/76/53/c16972708cbb79f2942922571a687c52bd109a7bd51175aeb7558dff2236/coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7", size = 252114, upload-time = "2026-02-09T12:56:35.749Z" }, + { url = "https://files.pythonhosted.org/packages/eb/c2/7ab36d8b8cc412bec9ea2d07c83c48930eb4ba649634ba00cb7e4e0f9017/coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00", size = 254220, upload-time = "2026-02-09T12:56:37.796Z" }, + { url = "https://files.pythonhosted.org/packages/d6/4d/cf52c9a3322c89a0e6febdfbc83bb45c0ed3c64ad14081b9503adee702e7/coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef", size = 256164, upload-time = "2026-02-09T12:56:39.016Z" }, + { url = "https://files.pythonhosted.org/packages/78/e9/eb1dd17bd6de8289df3580e967e78294f352a5df8a57ff4671ee5fc3dcd0/coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903", size = 250325, upload-time = "2026-02-09T12:56:40.668Z" }, + { url = "https://files.pythonhosted.org/packages/71/07/8c1542aa873728f72267c07278c5cc0ec91356daf974df21335ccdb46368/coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f", size = 251913, upload-time = "2026-02-09T12:56:41.97Z" }, + { url = "https://files.pythonhosted.org/packages/74/d7/c62e2c5e4483a748e27868e4c32ad3daa9bdddbba58e1bc7a15e252baa74/coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299", size = 249974, upload-time = "2026-02-09T12:56:43.323Z" }, + { url = "https://files.pythonhosted.org/packages/98/9f/4c5c015a6e98ced54efd0f5cf8d31b88e5504ecb6857585fc0161bb1e600/coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505", size = 253741, upload-time = "2026-02-09T12:56:45.155Z" }, + { url = 
"https://files.pythonhosted.org/packages/bd/59/0f4eef89b9f0fcd9633b5d350016f54126ab49426a70ff4c4e87446cabdc/coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6", size = 249695, upload-time = "2026-02-09T12:56:46.636Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2c/b7476f938deb07166f3eb281a385c262675d688ff4659ad56c6c6b8e2e70/coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9", size = 250599, upload-time = "2026-02-09T12:56:48.13Z" }, + { url = "https://files.pythonhosted.org/packages/b8/34/c3420709d9846ee3785b9f2831b4d94f276f38884032dca1457fa83f7476/coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9", size = 221780, upload-time = "2026-02-09T12:56:50.479Z" }, + { url = "https://files.pythonhosted.org/packages/61/08/3d9c8613079d2b11c185b865de9a4c1a68850cfda2b357fae365cf609f29/coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f", size = 222715, upload-time = "2026-02-09T12:56:51.815Z" }, + { url = "https://files.pythonhosted.org/packages/18/1a/54c3c80b2f056164cc0a6cdcb040733760c7c4be9d780fe655f356f433e4/coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f", size = 221385, upload-time = "2026-02-09T12:56:53.194Z" }, + { url = "https://files.pythonhosted.org/packages/d1/81/4ce2fdd909c5a0ed1f6dedb88aa57ab79b6d1fbd9b588c1ac7ef45659566/coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459", size = 219449, upload-time = "2026-02-09T12:56:54.889Z" }, + { url = "https://files.pythonhosted.org/packages/5d/96/5238b1efc5922ddbdc9b0db9243152c09777804fb7c02ad1741eb18a11c0/coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3", size = 219810, upload-time = "2026-02-09T12:56:56.33Z" }, + { url = "https://files.pythonhosted.org/packages/78/72/2f372b726d433c9c35e56377cf1d513b4c16fe51841060d826b95caacec1/coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634", size = 251308, upload-time = "2026-02-09T12:56:57.858Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a0/2ea570925524ef4e00bb6c82649f5682a77fac5ab910a65c9284de422600/coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3", size = 254052, upload-time = "2026-02-09T12:56:59.754Z" }, + { url = "https://files.pythonhosted.org/packages/e8/ac/45dc2e19a1939098d783c846e130b8f862fbb50d09e0af663988f2f21973/coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa", size = 255165, upload-time = "2026-02-09T12:57:01.287Z" }, + { url = "https://files.pythonhosted.org/packages/2d/4d/26d236ff35abc3b5e63540d3386e4c3b192168c1d96da5cb2f43c640970f/coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3", size = 257432, upload-time = 
"2026-02-09T12:57:02.637Z" }, + { url = "https://files.pythonhosted.org/packages/ec/55/14a966c757d1348b2e19caf699415a2a4c4f7feaa4bbc6326a51f5c7dd1b/coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a", size = 251716, upload-time = "2026-02-09T12:57:04.056Z" }, + { url = "https://files.pythonhosted.org/packages/77/33/50116647905837c66d28b2af1321b845d5f5d19be9655cb84d4a0ea806b4/coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7", size = 253089, upload-time = "2026-02-09T12:57:05.503Z" }, + { url = "https://files.pythonhosted.org/packages/c2/b4/8efb11a46e3665d92635a56e4f2d4529de6d33f2cb38afd47d779d15fc99/coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc", size = 251232, upload-time = "2026-02-09T12:57:06.879Z" }, + { url = "https://files.pythonhosted.org/packages/51/24/8cd73dd399b812cc76bb0ac260e671c4163093441847ffe058ac9fda1e32/coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47", size = 255299, upload-time = "2026-02-09T12:57:08.245Z" }, + { url = "https://files.pythonhosted.org/packages/03/94/0a4b12f1d0e029ce1ccc1c800944a9984cbe7d678e470bb6d3c6bc38a0da/coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985", size = 250796, upload-time = "2026-02-09T12:57:10.142Z" }, + { url = "https://files.pythonhosted.org/packages/73/44/6002fbf88f6698ca034360ce474c406be6d5a985b3fdb3401128031eef6b/coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0", size = 252673, upload-time = "2026-02-09T12:57:12.197Z" }, + { url = "https://files.pythonhosted.org/packages/de/c6/a0279f7c00e786be75a749a5674e6fa267bcbd8209cd10c9a450c655dfa7/coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246", size = 221990, upload-time = "2026-02-09T12:57:14.085Z" }, + { url = "https://files.pythonhosted.org/packages/77/4e/c0a25a425fcf5557d9abd18419c95b63922e897bc86c1f327f155ef234a9/coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126", size = 222800, upload-time = "2026-02-09T12:57:15.944Z" }, + { url = "https://files.pythonhosted.org/packages/47/ac/92da44ad9a6f4e3a7debd178949d6f3769bedca33830ce9b1dcdab589a37/coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d", size = 221415, upload-time = "2026-02-09T12:57:17.497Z" }, + { url = "https://files.pythonhosted.org/packages/0d/4a/331fe2caf6799d591109bb9c08083080f6de90a823695d412a935622abb2/coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0", size = 211242, upload-time = "2026-02-09T12:59:02.032Z" }, ] [package.optional-dependencies] @@ -1672,7 +1684,7 @@ requires-dist = [ { name = "opentelemetry-instrumentation-httpx", specifier = "==0.49b0" }, { name = "opentelemetry-instrumentation-redis", specifier = "==0.49b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.49b0" }, - { name = "opentelemetry-propagator-b3", specifier = "==1.28.0" }, + { name = "opentelemetry-propagator-b3", 
specifier = "==1.40.0" }, { name = "opentelemetry-proto", specifier = "==1.28.0" }, { name = "opentelemetry-sdk", specifier = "==1.28.0" }, { name = "opentelemetry-semantic-conventions", specifier = "==0.49b0" }, @@ -1686,7 +1698,7 @@ requires-dist = [ { name = "pydantic", specifier = "~=2.12.5" }, { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, - { name = "pyjwt", specifier = "~=2.11.0" }, + { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypdfium2", specifier = "==5.2.0" }, { name = "python-docx", specifier = "~=1.2.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, @@ -1713,48 +1725,48 @@ dev = [ { name = "basedpyright", specifier = "~=1.38.2" }, { name = "boto3-stubs", specifier = ">=1.38.20" }, { name = "celery-types", specifier = ">=0.23.0" }, - { name = "coverage", specifier = "~=7.2.4" }, - { name = "dotenv-linter", specifier = "~=0.5.0" }, - { name = "faker", specifier = "~=38.2.0" }, + { name = "coverage", specifier = "~=7.13.4" }, + { name = "dotenv-linter", specifier = "~=0.7.0" }, + { name = "faker", specifier = "~=40.8.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.19.1" }, - { name = "pandas-stubs", specifier = "~=2.2.3" }, + { name = "pandas-stubs", specifier = "~=3.0.0" }, { name = "pyrefly", specifier = ">=0.55.0" }, - { name = "pytest", specifier = "~=8.3.2" }, - { name = "pytest-benchmark", specifier = "~=4.0.0" }, - { name = "pytest-cov", specifier = "~=4.1.0" }, + { name = "pytest", specifier = "~=9.0.2" }, + { name = "pytest-benchmark", specifier = "~=5.2.3" }, + { name = "pytest-cov", specifier = "~=7.0.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, - { name = "pytest-mock", specifier = "~=3.14.0" }, + { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, - { name = "ruff", specifier = "~=0.14.0" }, + { name = "ruff", specifier = "~=0.15.5" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.13.2" }, { name = "types-aiofiles", specifier = "~=25.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, - { name = "types-cachetools", specifier = "~=5.5.0" }, + { name = "types-cachetools", specifier = "~=6.2.0" }, { name = "types-cffi", specifier = ">=1.17.0" }, { name = "types-colorama", specifier = "~=0.4.15" }, { name = "types-defusedxml", specifier = "~=0.7.0" }, - { name = "types-deprecated", specifier = "~=1.2.15" }, - { name = "types-docutils", specifier = "~=0.21.0" }, - { name = "types-flask-cors", specifier = "~=5.0.0" }, + { name = "types-deprecated", specifier = "~=1.3.1" }, + { name = "types-docutils", specifier = "~=0.22.3" }, + { name = "types-flask-cors", specifier = "~=6.0.0" }, { name = "types-flask-migrate", specifier = "~=4.1.0" }, { name = "types-gevent", specifier = "~=25.9.0" }, { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.23.0" }, + { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, - { name = "types-oauthlib", specifier = "~=3.2.0" }, + { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = 
"types-objgraph", specifier = "~=3.6.0" }, { name = "types-olefile", specifier = "~=0.47.0" }, { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, - { name = "types-protobuf", specifier = "~=5.29.1" }, + { name = "types-protobuf", specifier = "~=6.32.1" }, { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, @@ -1762,10 +1774,10 @@ dev = [ { name = "types-pyopenssl", specifier = ">=24.1.0" }, { name = "types-python-dateutil", specifier = "~=2.9.0" }, { name = "types-python-http-client", specifier = ">=3.3.7.20240910" }, - { name = "types-pywin32", specifier = "~=310.0.0" }, + { name = "types-pywin32", specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, - { name = "types-regex", specifier = "~=2024.11.6" }, + { name = "types-regex", specifier = "~=2026.2.28" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1860,18 +1872,18 @@ wheels = [ [[package]] name = "dotenv-linter" -version = "0.5.0" +version = "0.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "click" }, { name = "click-default-group" }, - { name = "ply" }, + { name = "lark" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/fe/77e184ccc312f6263cbcc48a9579eec99f5c7ff72a9b1bd7812cafc22bbb/dotenv_linter-0.5.0.tar.gz", hash = "sha256:4862a8393e5ecdfb32982f1b32dbc006fff969a7b3c8608ba7db536108beeaea", size = 15346, upload-time = "2024-03-13T11:52:10.52Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/e5/515ca4e069b70ba0be477ab0a193855c08066f9ef1a9350dcfbdc8f12f87/dotenv_linter-0.7.0.tar.gz", hash = "sha256:24ed93c1028d6305d6787e51773badf3346e53012ad4f5ada9cf747d2da6de13", size = 14033, upload-time = "2025-04-28T17:40:00.771Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/01/62ed4374340e6cf17c5084828974d96db8085e4018439ac41dc3cbbbcab3/dotenv_linter-0.5.0-py3-none-any.whl", hash = "sha256:fd01cca7f2140cb1710f49cbc1bf0e62397a75a6f0522d26a8b9b2331143c8bd", size = 21770, upload-time = "2024-03-13T11:52:08.607Z" }, + { url = "https://files.pythonhosted.org/packages/6e/5e/e26881b8d6bd6498c1a7225fba8ead3626a9f4b2d7d29dd272a875753d0d/dotenv_linter-0.7.0-py3-none-any.whl", hash = "sha256:0ffdf0c7435bd638aba5ff6cc9ea53bf093488bf1c722e363e902008659bb1fb", size = 19806, upload-time = "2025-04-28T17:39:58.395Z" }, ] [[package]] @@ -1965,14 +1977,14 @@ wheels = [ [[package]] name = "faker" -version = "38.2.0" +version = "40.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "tzdata" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" } +sdist = { url = "https://files.pythonhosted.org/packages/70/03/14428edc541467c460d363f6e94bee9acc271f3e62470630fc9a647d0cf2/faker-40.8.0.tar.gz", hash = "sha256:936a3c9be6c004433f20aa4d99095df5dec82b8c7ad07459756041f8c1728875", size = 1956493, upload-time = "2026-03-04T16:18:48.161Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" }, + { url = "https://files.pythonhosted.org/packages/4c/3b/c6348f1e285e75b069085b18110a4e6325b763a5d35d5e204356fc7c20b3/faker-40.8.0-py3-none-any.whl", hash = "sha256:eb21bdba18f7a8375382eb94fb436fce07046893dc94cb20817d28deb0c3d579", size = 1989124, upload-time = "2026-03-04T16:18:46.45Z" }, ] [[package]] @@ -2039,11 +2051,11 @@ wheels = [ [[package]] name = "fickling" -version = "0.1.9" +version = "0.1.10" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/25/bd/ca7127df0201596b0b30f9ab3d36e565bb9d6f8f4da1560758b817e81b65/fickling-0.1.9.tar.gz", hash = "sha256:bb518c2fd833555183bc46b6903bb4022f3ae0436a69c3fb149cfc75eebaac33", size = 336940, upload-time = "2026-03-03T23:32:19.449Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/06/1818b8f52267599e54041349c553d5894e17ec8a539a246eb3f9eaf05629/fickling-0.1.10.tar.gz", hash = "sha256:8c8b76abd29936f1a5932e4087b8c8becb2d7ab1cf08549e63519ebcb2f71644", size = 338062, upload-time = "2026-03-13T16:34:29.287Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/92/49/c597bad508c74917901432b41ae5a8f036839a7fb8d0d29a89765f5d3643/fickling-0.1.9-py3-none-any.whl", hash = "sha256:ccc3ce3b84733406ade2fe749717f6e428047335157c6431eefd3e7e970a06d1", size = 52786, upload-time = "2026-03-03T23:32:17.533Z" }, + { url = "https://files.pythonhosted.org/packages/05/86/620960dff970da5311f05e25fc045dac8495557d51030e5a0827084b18fd/fickling-0.1.10-py3-none-any.whl", hash = "sha256:962c35c38ece1b3632fc119c0f4cb1eebc02dc6d65bfd93a1803afd42ca91d25", size = 52853, upload-time = "2026-03-13T16:34:27.821Z" }, ] [[package]] @@ -3322,6 +3334,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/a8/4202ca65561213ec84ca3800b1d4e5d37a1441cddeec533367ecbca7f408/langsmith-0.7.16-py3-none-any.whl", hash = "sha256:c84a7a06938025fe0aad992acc546dd75ce3f757ba8ee5b00ad914911d4fc02e", size = 347538, upload-time = "2026-03-09T21:11:15.02Z" }, ] +[[package]] +name = "lark" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" }, +] + [[package]] name = "librt" version = "0.8.1" @@ -4314,15 +4335,15 @@ wheels = [ [[package]] name = "opentelemetry-propagator-b3" -version = "1.28.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "opentelemetry-api" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/1d/225ea036785119964509e92f4e1bc0313ba6ec790fbf51bd363abafeafae/opentelemetry_propagator_b3-1.28.0.tar.gz", hash = "sha256:cf6f0d2a1881c4858898be47e8a94b11bc5b16fc73b6c37ebfa2121c4825adc6", size = 9592, upload-time 
= "2024-11-05T19:14:57.193Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/fe/e0c84af5c654ec42165ba57af83c7f67e4b8af77f836ddc29dee59ff73c6/opentelemetry_propagator_b3-1.40.0.tar.gz", hash = "sha256:59b6925498947c08a1b7e0dd38193ff97e5009bec74ec23824300c2e32f77bcf", size = 9587, upload-time = "2026-03-04T14:17:30.079Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/fa/438d53d73a6c45df5d416b56dc371a65d0b07859bc107ab632349a079d4a/opentelemetry_propagator_b3-1.28.0-py3-none-any.whl", hash = "sha256:9f6923a5da56d7da6724e4fdd758a67ede2a2732efb929e538cf6fea337700c5", size = 8917, upload-time = "2024-11-05T19:14:37.317Z" }, + { url = "https://files.pythonhosted.org/packages/8f/84/8654cc0539b5145046b2e60d058cebad401a600dd0b1240f1711c6788643/opentelemetry_propagator_b3-1.40.0-py3-none-any.whl", hash = "sha256:cb72a1698fd1d1b434f70dc90c1de62da8ade1dd84850d1f040eccf6a420fa7b", size = 8922, upload-time = "2026-03-04T14:17:14.732Z" }, ] [[package]] @@ -4440,40 +4461,40 @@ wheels = [ [[package]] name = "orjson" -version = "3.11.4" +version = "3.11.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c6/fe/ed708782d6709cc60eb4c2d8a361a440661f74134675c72990f2c48c785f/orjson-3.11.4.tar.gz", hash = "sha256:39485f4ab4c9b30a3943cfe99e1a213c4776fb69e8abd68f66b83d5a0b0fdc6d", size = 5945188, upload-time = "2025-10-24T15:50:38.027Z" } +sdist = { url = "https://files.pythonhosted.org/packages/70/a3/4e09c61a5f0c521cba0bb433639610ae037437669f1a4cbc93799e731d78/orjson-3.11.6.tar.gz", hash = "sha256:0a54c72259f35299fd033042367df781c2f66d10252955ca1efb7db309b954cb", size = 6175856, upload-time = "2026-01-29T15:13:07.942Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/1d/1ea6005fffb56715fd48f632611e163d1604e8316a5bad2288bee9a1c9eb/orjson-3.11.4-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5e59d23cd93ada23ec59a96f215139753fbfe3a4d989549bcb390f8c00370b39", size = 243498, upload-time = "2025-10-24T15:48:48.101Z" }, - { url = "https://files.pythonhosted.org/packages/37/d7/ffed10c7da677f2a9da307d491b9eb1d0125b0307019c4ad3d665fd31f4f/orjson-3.11.4-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:5c3aedecfc1beb988c27c79d52ebefab93b6c3921dbec361167e6559aba2d36d", size = 128961, upload-time = "2025-10-24T15:48:49.571Z" }, - { url = "https://files.pythonhosted.org/packages/a2/96/3e4d10a18866d1368f73c8c44b7fe37cc8a15c32f2a7620be3877d4c55a3/orjson-3.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da9e5301f1c2caa2a9a4a303480d79c9ad73560b2e7761de742ab39fe59d9175", size = 130321, upload-time = "2025-10-24T15:48:50.713Z" }, - { url = "https://files.pythonhosted.org/packages/eb/1f/465f66e93f434f968dd74d5b623eb62c657bdba2332f5a8be9f118bb74c7/orjson-3.11.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8873812c164a90a79f65368f8f96817e59e35d0cc02786a5356f0e2abed78040", size = 129207, upload-time = "2025-10-24T15:48:52.193Z" }, - { url = "https://files.pythonhosted.org/packages/28/43/d1e94837543321c119dff277ae8e348562fe8c0fafbb648ef7cb0c67e521/orjson-3.11.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d7feb0741ebb15204e748f26c9638e6665a5fa93c37a2c73d64f1669b0ddc63", size = 136323, upload-time = "2025-10-24T15:48:54.806Z" }, - { url = 
"https://files.pythonhosted.org/packages/bf/04/93303776c8890e422a5847dd012b4853cdd88206b8bbd3edc292c90102d1/orjson-3.11.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01ee5487fefee21e6910da4c2ee9eef005bee568a0879834df86f888d2ffbdd9", size = 137440, upload-time = "2025-10-24T15:48:56.326Z" }, - { url = "https://files.pythonhosted.org/packages/1e/ef/75519d039e5ae6b0f34d0336854d55544ba903e21bf56c83adc51cd8bf82/orjson-3.11.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d40d46f348c0321df01507f92b95a377240c4ec31985225a6668f10e2676f9a", size = 136680, upload-time = "2025-10-24T15:48:57.476Z" }, - { url = "https://files.pythonhosted.org/packages/b5/18/bf8581eaae0b941b44efe14fee7b7862c3382fbc9a0842132cfc7cf5ecf4/orjson-3.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95713e5fc8af84d8edc75b785d2386f653b63d62b16d681687746734b4dfc0be", size = 136160, upload-time = "2025-10-24T15:48:59.631Z" }, - { url = "https://files.pythonhosted.org/packages/c4/35/a6d582766d351f87fc0a22ad740a641b0a8e6fc47515e8614d2e4790ae10/orjson-3.11.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ad73ede24f9083614d6c4ca9a85fe70e33be7bf047ec586ee2363bc7418fe4d7", size = 140318, upload-time = "2025-10-24T15:49:00.834Z" }, - { url = "https://files.pythonhosted.org/packages/76/b3/5a4801803ab2e2e2d703bce1a56540d9f99a9143fbec7bf63d225044fef8/orjson-3.11.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:842289889de515421f3f224ef9c1f1efb199a32d76d8d2ca2706fa8afe749549", size = 406330, upload-time = "2025-10-24T15:49:02.327Z" }, - { url = "https://files.pythonhosted.org/packages/80/55/a8f682f64833e3a649f620eafefee175cbfeb9854fc5b710b90c3bca45df/orjson-3.11.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3b2427ed5791619851c52a1261b45c233930977e7de8cf36de05636c708fa905", size = 149580, upload-time = "2025-10-24T15:49:03.517Z" }, - { url = "https://files.pythonhosted.org/packages/ad/e4/c132fa0c67afbb3eb88274fa98df9ac1f631a675e7877037c611805a4413/orjson-3.11.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3c36e524af1d29982e9b190573677ea02781456b2e537d5840e4538a5ec41907", size = 139846, upload-time = "2025-10-24T15:49:04.761Z" }, - { url = "https://files.pythonhosted.org/packages/54/06/dc3491489efd651fef99c5908e13951abd1aead1257c67f16135f95ce209/orjson-3.11.4-cp311-cp311-win32.whl", hash = "sha256:87255b88756eab4a68ec61837ca754e5d10fa8bc47dc57f75cedfeaec358d54c", size = 135781, upload-time = "2025-10-24T15:49:05.969Z" }, - { url = "https://files.pythonhosted.org/packages/79/b7/5e5e8d77bd4ea02a6ac54c42c818afb01dd31961be8a574eb79f1d2cfb1e/orjson-3.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:e2d5d5d798aba9a0e1fede8d853fa899ce2cb930ec0857365f700dffc2c7af6a", size = 131391, upload-time = "2025-10-24T15:49:07.355Z" }, - { url = "https://files.pythonhosted.org/packages/0f/dc/9484127cc1aa213be398ed735f5f270eedcb0c0977303a6f6ddc46b60204/orjson-3.11.4-cp311-cp311-win_arm64.whl", hash = "sha256:6bb6bb41b14c95d4f2702bce9975fda4516f1db48e500102fc4d8119032ff045", size = 126252, upload-time = "2025-10-24T15:49:08.869Z" }, - { url = "https://files.pythonhosted.org/packages/63/51/6b556192a04595b93e277a9ff71cd0cc06c21a7df98bcce5963fa0f5e36f/orjson-3.11.4-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d4371de39319d05d3f482f372720b841c841b52f5385bd99c61ed69d55d9ab50", size = 243571, upload-time = "2025-10-24T15:49:10.008Z" }, - { url = 
"https://files.pythonhosted.org/packages/1c/2c/2602392ddf2601d538ff11848b98621cd465d1a1ceb9db9e8043181f2f7b/orjson-3.11.4-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:e41fd3b3cac850eaae78232f37325ed7d7436e11c471246b87b2cd294ec94853", size = 128891, upload-time = "2025-10-24T15:49:11.297Z" }, - { url = "https://files.pythonhosted.org/packages/4e/47/bf85dcf95f7a3a12bf223394a4f849430acd82633848d52def09fa3f46ad/orjson-3.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600e0e9ca042878c7fdf189cf1b028fe2c1418cc9195f6cb9824eb6ed99cb938", size = 130137, upload-time = "2025-10-24T15:49:12.544Z" }, - { url = "https://files.pythonhosted.org/packages/b4/4d/a0cb31007f3ab6f1fd2a1b17057c7c349bc2baf8921a85c0180cc7be8011/orjson-3.11.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7bbf9b333f1568ef5da42bc96e18bf30fd7f8d54e9ae066d711056add508e415", size = 129152, upload-time = "2025-10-24T15:49:13.754Z" }, - { url = "https://files.pythonhosted.org/packages/f7/ef/2811def7ce3d8576b19e3929fff8f8f0d44bc5eb2e0fdecb2e6e6cc6c720/orjson-3.11.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4806363144bb6e7297b8e95870e78d30a649fdc4e23fc84daa80c8ebd366ce44", size = 136834, upload-time = "2025-10-24T15:49:15.307Z" }, - { url = "https://files.pythonhosted.org/packages/00/d4/9aee9e54f1809cec8ed5abd9bc31e8a9631d19460e3b8470145d25140106/orjson-3.11.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad355e8308493f527d41154e9053b86a5be892b3b359a5c6d5d95cda23601cb2", size = 137519, upload-time = "2025-10-24T15:49:16.557Z" }, - { url = "https://files.pythonhosted.org/packages/db/ea/67bfdb5465d5679e8ae8d68c11753aaf4f47e3e7264bad66dc2f2249e643/orjson-3.11.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a7517482667fb9f0ff1b2f16fe5829296ed7a655d04d68cd9711a4d8a4e708", size = 136749, upload-time = "2025-10-24T15:49:17.796Z" }, - { url = "https://files.pythonhosted.org/packages/01/7e/62517dddcfce6d53a39543cd74d0dccfcbdf53967017c58af68822100272/orjson-3.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97eb5942c7395a171cbfecc4ef6701fc3c403e762194683772df4c54cfbb2210", size = 136325, upload-time = "2025-10-24T15:49:19.347Z" }, - { url = "https://files.pythonhosted.org/packages/18/ae/40516739f99ab4c7ec3aaa5cc242d341fcb03a45d89edeeaabc5f69cb2cf/orjson-3.11.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:149d95d5e018bdd822e3f38c103b1a7c91f88d38a88aada5c4e9b3a73a244241", size = 140204, upload-time = "2025-10-24T15:49:20.545Z" }, - { url = "https://files.pythonhosted.org/packages/82/18/ff5734365623a8916e3a4037fcef1cd1782bfc14cf0992afe7940c5320bf/orjson-3.11.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:624f3951181eb46fc47dea3d221554e98784c823e7069edb5dbd0dc826ac909b", size = 406242, upload-time = "2025-10-24T15:49:21.884Z" }, - { url = "https://files.pythonhosted.org/packages/e1/43/96436041f0a0c8c8deca6a05ebeaf529bf1de04839f93ac5e7c479807aec/orjson-3.11.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:03bfa548cf35e3f8b3a96c4e8e41f753c686ff3d8e182ce275b1751deddab58c", size = 150013, upload-time = "2025-10-24T15:49:23.185Z" }, - { url = "https://files.pythonhosted.org/packages/1b/48/78302d98423ed8780479a1e682b9aecb869e8404545d999d34fa486e573e/orjson-3.11.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:525021896afef44a68148f6ed8a8bf8375553d6066c7f48537657f64823565b9", size = 139951, upload-time = "2025-10-24T15:49:24.428Z" }, - { 
url = "https://files.pythonhosted.org/packages/4a/7b/ad613fdcdaa812f075ec0875143c3d37f8654457d2af17703905425981bf/orjson-3.11.4-cp312-cp312-win32.whl", hash = "sha256:b58430396687ce0f7d9eeb3dd47761ca7d8fda8e9eb92b3077a7a353a75efefa", size = 136049, upload-time = "2025-10-24T15:49:25.973Z" }, - { url = "https://files.pythonhosted.org/packages/b9/3c/9cf47c3ff5f39b8350fb21ba65d789b6a1129d4cbb3033ba36c8a9023520/orjson-3.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:c6dbf422894e1e3c80a177133c0dda260f81428f9de16d61041949f6a2e5c140", size = 131461, upload-time = "2025-10-24T15:49:27.259Z" }, - { url = "https://files.pythonhosted.org/packages/c6/3b/e2425f61e5825dc5b08c2a5a2b3af387eaaca22a12b9c8c01504f8614c36/orjson-3.11.4-cp312-cp312-win_arm64.whl", hash = "sha256:d38d2bc06d6415852224fcc9c0bfa834c25431e466dc319f0edd56cca81aa96e", size = 126167, upload-time = "2025-10-24T15:49:28.511Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fd/d6b0a36854179b93ed77839f107c4089d91cccc9f9ba1b752b6e3bac5f34/orjson-3.11.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e259e85a81d76d9665f03d6129e09e4435531870de5961ddcd0bf6e3a7fde7d7", size = 250029, upload-time = "2026-01-29T15:11:35.942Z" }, + { url = "https://files.pythonhosted.org/packages/a3/bb/22902619826641cf3b627c24aab62e2ad6b571bdd1d34733abb0dd57f67a/orjson-3.11.6-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:52263949f41b4a4822c6b1353bcc5ee2f7109d53a3b493501d3369d6d0e7937a", size = 134518, upload-time = "2026-01-29T15:11:37.347Z" }, + { url = "https://files.pythonhosted.org/packages/72/90/7a818da4bba1de711a9653c420749c0ac95ef8f8651cbc1dca551f462fe0/orjson-3.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6439e742fa7834a24698d358a27346bb203bff356ae0402e7f5df8f749c621a8", size = 137917, upload-time = "2026-01-29T15:11:38.511Z" }, + { url = "https://files.pythonhosted.org/packages/59/0f/02846c1cac8e205cb3822dd8aa8f9114acda216f41fd1999ace6b543418d/orjson-3.11.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b81ffd68f084b4e993e3867acb554a049fa7787cc8710bbcc1e26965580d99be", size = 134923, upload-time = "2026-01-29T15:11:39.711Z" }, + { url = "https://files.pythonhosted.org/packages/94/cf/aeaf683001b474bb3c3c757073a4231dfdfe8467fceaefa5bfd40902c99f/orjson-3.11.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5a5468e5e60f7ef6d7f9044b06c8f94a3c56ba528c6e4f7f06ae95164b595ec", size = 140752, upload-time = "2026-01-29T15:11:41.347Z" }, + { url = "https://files.pythonhosted.org/packages/fc/fe/dad52d8315a65f084044a0819d74c4c9daf9ebe0681d30f525b0d29a31f0/orjson-3.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72c5005eb45bd2535632d4f3bec7ad392832cfc46b62a3021da3b48a67734b45", size = 144201, upload-time = "2026-01-29T15:11:42.537Z" }, + { url = "https://files.pythonhosted.org/packages/36/bc/ab070dd421565b831801077f1e390c4d4af8bfcecafc110336680a33866b/orjson-3.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b14dd49f3462b014455a28a4d810d3549bf990567653eb43765cd847df09145", size = 142380, upload-time = "2026-01-29T15:11:44.309Z" }, + { url = "https://files.pythonhosted.org/packages/e6/d8/4b581c725c3a308717f28bf45a9fdac210bca08b67e8430143699413ff06/orjson-3.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bb2c1ea30ef302f0f89f9bf3e7f9ab5e2af29dc9f80eb87aa99788e4e2d65", size = 145582, upload-time = 
"2026-01-29T15:11:45.506Z" }, + { url = "https://files.pythonhosted.org/packages/5b/a2/09aab99b39f9a7f175ea8fa29adb9933a3d01e7d5d603cdee7f1c40c8da2/orjson-3.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:825e0a85d189533c6bff7e2fc417a28f6fcea53d27125c4551979aecd6c9a197", size = 147270, upload-time = "2026-01-29T15:11:46.782Z" }, + { url = "https://files.pythonhosted.org/packages/b8/2f/5ef8eaf7829dc50da3bf497c7775b21ee88437bc8c41f959aa3504ca6631/orjson-3.11.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:b04575417a26530637f6ab4b1f7b4f666eb0433491091da4de38611f97f2fcf3", size = 421222, upload-time = "2026-01-29T15:11:48.106Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b0/dd6b941294c2b5b13da5fdc7e749e58d0c55a5114ab37497155e83050e95/orjson-3.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b83eb2e40e8c4da6d6b340ee6b1d6125f5195eb1b0ebb7eac23c6d9d4f92d224", size = 155562, upload-time = "2026-01-29T15:11:49.408Z" }, + { url = "https://files.pythonhosted.org/packages/8e/09/43924331a847476ae2f9a16bd6d3c9dab301265006212ba0d3d7fd58763a/orjson-3.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1f42da604ee65a6b87eef858c913ce3e5777872b19321d11e6fc6d21de89b64f", size = 147432, upload-time = "2026-01-29T15:11:50.635Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e9/d9865961081816909f6b49d880749dbbd88425afd7c5bbce0549e2290d77/orjson-3.11.6-cp311-cp311-win32.whl", hash = "sha256:5ae45df804f2d344cffb36c43fdf03c82fb6cd247f5faa41e21891b40dfbf733", size = 139623, upload-time = "2026-01-29T15:11:51.82Z" }, + { url = "https://files.pythonhosted.org/packages/b4/f9/6836edb92f76eec1082919101eb1145d2f9c33c8f2c5e6fa399b82a2aaa8/orjson-3.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:f4295948d65ace0a2d8f2c4ccc429668b7eb8af547578ec882e16bf79b0050b2", size = 136647, upload-time = "2026-01-29T15:11:53.454Z" }, + { url = "https://files.pythonhosted.org/packages/b3/0c/4954082eea948c9ae52ee0bcbaa2f99da3216a71bcc314ab129bde22e565/orjson-3.11.6-cp311-cp311-win_arm64.whl", hash = "sha256:314e9c45e0b81b547e3a1cfa3df3e07a815821b3dac9fe8cb75014071d0c16a4", size = 135327, upload-time = "2026-01-29T15:11:56.616Z" }, + { url = "https://files.pythonhosted.org/packages/14/ba/759f2879f41910b7e5e0cdbd9cf82a4f017c527fb0e972e9869ca7fe4c8e/orjson-3.11.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6f03f30cd8953f75f2a439070c743c7336d10ee940da918d71c6f3556af3ddcf", size = 249988, upload-time = "2026-01-29T15:11:58.294Z" }, + { url = "https://files.pythonhosted.org/packages/f0/70/54cecb929e6c8b10104fcf580b0cc7dc551aa193e83787dd6f3daba28bb5/orjson-3.11.6-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:af44baae65ef386ad971469a8557a0673bb042b0b9fd4397becd9c2dfaa02588", size = 134445, upload-time = "2026-01-29T15:11:59.819Z" }, + { url = "https://files.pythonhosted.org/packages/f2/6f/ec0309154457b9ba1ad05f11faa4441f76037152f75e1ac577db3ce7ca96/orjson-3.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c310a48542094e4f7dbb6ac076880994986dda8ca9186a58c3cb70a3514d3231", size = 137708, upload-time = "2026-01-29T15:12:01.488Z" }, + { url = "https://files.pythonhosted.org/packages/20/52/3c71b80840f8bab9cb26417302707b7716b7d25f863f3a541bcfa232fe6e/orjson-3.11.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8dfa7a5d387f15ecad94cb6b2d2d5f4aeea64efd8d526bfc03c9812d01e1cc0", size = 134798, upload-time = "2026-01-29T15:12:02.705Z" }, + { url = 
"https://files.pythonhosted.org/packages/30/51/b490a43b22ff736282360bd02e6bded455cf31dfc3224e01cd39f919bbd2/orjson-3.11.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba8daee3e999411b50f8b50dbb0a3071dd1845f3f9a1a0a6fa6de86d1689d84d", size = 140839, upload-time = "2026-01-29T15:12:03.956Z" }, + { url = "https://files.pythonhosted.org/packages/95/bc/4bcfe4280c1bc63c5291bb96f98298845b6355da2226d3400e17e7b51e53/orjson-3.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f89d104c974eafd7436d7a5fdbc57f7a1e776789959a2f4f1b2eab5c62a339f4", size = 144080, upload-time = "2026-01-29T15:12:05.151Z" }, + { url = "https://files.pythonhosted.org/packages/01/74/22970f9ead9ab1f1b5f8c227a6c3aa8d71cd2c5acd005868a1d44f2362fa/orjson-3.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e2e2456788ca5ea75616c40da06fc885a7dc0389780e8a41bf7c5389ba257b", size = 142435, upload-time = "2026-01-29T15:12:06.641Z" }, + { url = "https://files.pythonhosted.org/packages/29/34/d564aff85847ab92c82ee43a7a203683566c2fca0723a5f50aebbe759603/orjson-3.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a42efebc45afabb1448001e90458c4020d5c64fbac8a8dc4045b777db76cb5a", size = 145631, upload-time = "2026-01-29T15:12:08.351Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ef/016957a3890752c4aa2368326ea69fa53cdc1fdae0a94a542b6410dbdf52/orjson-3.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:71b7cbef8471324966c3738c90ba38775563ef01b512feb5ad4805682188d1b9", size = 147058, upload-time = "2026-01-29T15:12:10.023Z" }, + { url = "https://files.pythonhosted.org/packages/56/cc/9a899c3972085645b3225569f91a30e221f441e5dc8126e6d060b971c252/orjson-3.11.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:f8515e5910f454fe9a8e13c2bb9dc4bae4c1836313e967e72eb8a4ad874f0248", size = 421161, upload-time = "2026-01-29T15:12:11.308Z" }, + { url = "https://files.pythonhosted.org/packages/21/a8/767d3fbd6d9b8fdee76974db40619399355fd49bf91a6dd2c4b6909ccf05/orjson-3.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:300360edf27c8c9bf7047345a94fddf3a8b8922df0ff69d71d854a170cb375cf", size = 155757, upload-time = "2026-01-29T15:12:12.776Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0b/205cd69ac87e2272e13ef3f5f03a3d4657e317e38c1b08aaa2ef97060bbc/orjson-3.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:caaed4dad39e271adfadc106fab634d173b2bb23d9cf7e67bd645f879175ebfc", size = 147446, upload-time = "2026-01-29T15:12:14.166Z" }, + { url = "https://files.pythonhosted.org/packages/de/c5/dd9f22aa9f27c54c7d05cc32f4580c9ac9b6f13811eeb81d6c4c3f50d6b1/orjson-3.11.6-cp312-cp312-win32.whl", hash = "sha256:955368c11808c89793e847830e1b1007503a5923ddadc108547d3b77df761044", size = 139717, upload-time = "2026-01-29T15:12:15.7Z" }, + { url = "https://files.pythonhosted.org/packages/23/a1/e62fc50d904486970315a1654b8cfb5832eb46abb18cd5405118e7e1fc79/orjson-3.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:2c68de30131481150073d90a5d227a4a421982f42c025ecdfb66157f9579e06f", size = 136711, upload-time = "2026-01-29T15:12:17.055Z" }, + { url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" }, ] [[package]] @@ -4557,15 +4578,14 @@ performance = [ [[package]] name = "pandas-stubs" 
-version = "2.2.3.250527" +version = "3.0.0.260204" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, - { name = "types-pytz" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5f/0d/5fe7f7f3596eb1c2526fea151e9470f86b379183d8b9debe44b2098651ca/pandas_stubs-2.2.3.250527.tar.gz", hash = "sha256:e2d694c4e72106055295ad143664e5c99e5815b07190d1ff85b73b13ff019e63", size = 106312, upload-time = "2025-05-27T15:24:29.716Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/1d/297ff2c7ea50a768a2247621d6451abb2a07c0e9be7ca6d36ebe371658e5/pandas_stubs-3.0.0.260204.tar.gz", hash = "sha256:bf9294b76352effcffa9cb85edf0bed1339a7ec0c30b8e1ac3d66b4228f1fbc3", size = 109383, upload-time = "2026-02-04T15:17:17.247Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/f8/46141ba8c9d7064dc5008bfb4a6ae5bd3c30e4c61c28b5c5ed485bf358ba/pandas_stubs-2.2.3.250527-py3-none-any.whl", hash = "sha256:cd0a49a95b8c5f944e605be711042a4dd8550e2c559b43d70ba2c4b524b66163", size = 159683, upload-time = "2025-05-27T15:24:28.4Z" }, + { url = "https://files.pythonhosted.org/packages/7c/2f/f91e4eee21585ff548e83358332d5632ee49f6b2dcd96cb5dca4e0468951/pandas_stubs-3.0.0.260204-py3-none-any.whl", hash = "sha256:5ab9e4d55a6e2752e9720828564af40d48c4f709e6a2c69b743014a6fcb6c241", size = 168540, upload-time = "2026-02-04T15:17:15.615Z" }, ] [[package]] @@ -4674,15 +4694,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] -[[package]] -name = "ply" -version = "3.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130, upload-time = "2018-02-15T19:01:31.097Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567, upload-time = "2018-02-15T19:01:27.172Z" }, -] - [[package]] name = "polyfile-weave" version = "0.5.8" @@ -5078,11 +5089,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.11.0" +version = "2.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/10/e8192be5f38f3e8e7e046716de4cae33d56fd5ae08927a823bb916be36c1/pyjwt-2.12.0.tar.gz", hash = "sha256:2f62390b667cd8257de560b850bb5a883102a388829274147f1d724453f8fb02", size = 102511, upload-time = "2026-03-12T17:15:30.831Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, + { url = 
"https://files.pythonhosted.org/packages/15/70/70f895f404d363d291dcf62c12c85fdd47619ad9674ac0f53364d035925a/pyjwt-2.12.0-py3-none-any.whl", hash = "sha256:9bb459d1bdd0387967d287f5656bf7ec2b9a26645d1961628cda1764e087fd6e", size = 29700, upload-time = "2026-03-12T17:15:29.257Z" }, ] [package.optional-dependencies] @@ -5246,43 +5257,45 @@ wheels = [ [[package]] name = "pytest" -version = "8.3.5" +version = "9.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, + { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891, upload-time = "2025-03-02T12:54:54.503Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634, upload-time = "2025-03-02T12:54:52.069Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] [[package]] name = "pytest-benchmark" -version = "4.0.0" +version = "5.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "py-cpuinfo" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/28/08/e6b0067efa9a1f2a1eb3043ecd8a0c48bfeb60d3255006dcc829d72d5da2/pytest-benchmark-4.0.0.tar.gz", hash = "sha256:fb0785b83efe599a6a956361c0691ae1dbb5318018561af10f3e915caa0048d1", size = 334641, upload-time = "2022-10-25T21:21:55.686Z" } +sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/a1/3b70862b5b3f830f0422844f25a823d0470739d994466be9dbbbb414d85a/pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6", size = 43951, upload-time = "2022-10-25T21:21:53.208Z" }, + { url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, ] [[package]] name = "pytest-cov" -version = "4.1.0" +version = "7.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, { name = "pytest" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/7a/15/da3df99fd551507694a9b01f512a2f6cf1254f33601605843c3775f39460/pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", size = 63245, upload-time = "2023-05-24T18:44:56.845Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949, upload-time = "2023-05-24T18:44:54.079Z" }, + { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] [[package]] @@ -5299,14 +5312,14 @@ wheels = [ [[package]] name = "pytest-mock" -version = "3.14.1" +version = "3.15.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241, upload-time = "2025-05-26T13:58:45.167Z" } +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" }, + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, ] [[package]] @@ -5818,28 +5831,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.6" +version = "0.15.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/f0/62b5a1a723fe183650109407fa56abb433b00aa1c0b9ba555f9c4efec2c6/ruff-0.14.6.tar.gz", hash = "sha256:6f0c742ca6a7783a736b867a263b9a7a80a45ce9bee391eeda296895f1b4e1cc", size = 5669501, upload-time = "2025-11-21T14:26:17.903Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/9b/840e0039e65fcf12758adf684d2289024d6140cde9268cc59887dc55189c/ruff-0.15.5.tar.gz", hash = "sha256:7c3601d3b6d76dce18c5c824fc8d06f4eef33d6df0c21ec7799510cde0f159a2", size = 4574214, upload-time = "2026-03-05T20:06:34.946Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/d2/7dd544116d107fffb24a0064d41a5d2ed1c9d6372d142f9ba108c8e39207/ruff-0.14.6-py3-none-linux_armv6l.whl", hash = "sha256:d724ac2f1c240dbd01a2ae98db5d1d9a5e1d9e96eba999d1c48e30062df578a3", size 
= 13326119, upload-time = "2025-11-21T14:25:24.2Z" }, - { url = "https://files.pythonhosted.org/packages/36/6a/ad66d0a3315d6327ed6b01f759d83df3c4d5f86c30462121024361137b6a/ruff-0.14.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9f7539ea257aa4d07b7ce87aed580e485c40143f2473ff2f2b75aee003186004", size = 13526007, upload-time = "2025-11-21T14:25:26.906Z" }, - { url = "https://files.pythonhosted.org/packages/a3/9d/dae6db96df28e0a15dea8e986ee393af70fc97fd57669808728080529c37/ruff-0.14.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7f6007e55b90a2a7e93083ba48a9f23c3158c433591c33ee2e99a49b889c6332", size = 12676572, upload-time = "2025-11-21T14:25:29.826Z" }, - { url = "https://files.pythonhosted.org/packages/76/a4/f319e87759949062cfee1b26245048e92e2acce900ad3a909285f9db1859/ruff-0.14.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a8e7b9d73d8728b68f632aa8e824ef041d068d231d8dbc7808532d3629a6bef", size = 13140745, upload-time = "2025-11-21T14:25:32.788Z" }, - { url = "https://files.pythonhosted.org/packages/95/d3/248c1efc71a0a8ed4e8e10b4b2266845d7dfc7a0ab64354afe049eaa1310/ruff-0.14.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d50d45d4553a3ebcbd33e7c5e0fe6ca4aafd9a9122492de357205c2c48f00775", size = 13076486, upload-time = "2025-11-21T14:25:35.601Z" }, - { url = "https://files.pythonhosted.org/packages/a5/19/b68d4563fe50eba4b8c92aa842149bb56dd24d198389c0ed12e7faff4f7d/ruff-0.14.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:118548dd121f8a21bfa8ab2c5b80e5b4aed67ead4b7567790962554f38e598ce", size = 13727563, upload-time = "2025-11-21T14:25:38.514Z" }, - { url = "https://files.pythonhosted.org/packages/47/ac/943169436832d4b0e867235abbdb57ce3a82367b47e0280fa7b4eabb7593/ruff-0.14.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:57256efafbfefcb8748df9d1d766062f62b20150691021f8ab79e2d919f7c11f", size = 15199755, upload-time = "2025-11-21T14:25:41.516Z" }, - { url = "https://files.pythonhosted.org/packages/c9/b9/288bb2399860a36d4bb0541cb66cce3c0f4156aaff009dc8499be0c24bf2/ruff-0.14.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ff18134841e5c68f8e5df1999a64429a02d5549036b394fafbe410f886e1989d", size = 14850608, upload-time = "2025-11-21T14:25:44.428Z" }, - { url = "https://files.pythonhosted.org/packages/ee/b1/a0d549dd4364e240f37e7d2907e97ee80587480d98c7799d2d8dc7a2f605/ruff-0.14.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:29c4b7ec1e66a105d5c27bd57fa93203637d66a26d10ca9809dc7fc18ec58440", size = 14118754, upload-time = "2025-11-21T14:25:47.214Z" }, - { url = "https://files.pythonhosted.org/packages/13/ac/9b9fe63716af8bdfddfacd0882bc1586f29985d3b988b3c62ddce2e202c3/ruff-0.14.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167843a6f78680746d7e226f255d920aeed5e4ad9c03258094a2d49d3028b105", size = 13949214, upload-time = "2025-11-21T14:25:50.002Z" }, - { url = "https://files.pythonhosted.org/packages/12/27/4dad6c6a77fede9560b7df6802b1b697e97e49ceabe1f12baf3ea20862e9/ruff-0.14.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:16a33af621c9c523b1ae006b1b99b159bf5ac7e4b1f20b85b2572455018e0821", size = 14106112, upload-time = "2025-11-21T14:25:52.841Z" }, - { url = "https://files.pythonhosted.org/packages/6a/db/23e322d7177873eaedea59a7932ca5084ec5b7e20cb30f341ab594130a71/ruff-0.14.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1432ab6e1ae2dc565a7eea707d3b03a0c234ef401482a6f1621bc1f427c2ff55", size = 
13035010, upload-time = "2025-11-21T14:25:55.536Z" }, - { url = "https://files.pythonhosted.org/packages/a8/9c/20e21d4d69dbb35e6a1df7691e02f363423658a20a2afacf2a2c011800dc/ruff-0.14.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4c55cfbbe7abb61eb914bfd20683d14cdfb38a6d56c6c66efa55ec6570ee4e71", size = 13054082, upload-time = "2025-11-21T14:25:58.625Z" }, - { url = "https://files.pythonhosted.org/packages/66/25/906ee6a0464c3125c8d673c589771a974965c2be1a1e28b5c3b96cb6ef88/ruff-0.14.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:efea3c0f21901a685fff4befda6d61a1bf4cb43de16da87e8226a281d614350b", size = 13303354, upload-time = "2025-11-21T14:26:01.816Z" }, - { url = "https://files.pythonhosted.org/packages/4c/58/60577569e198d56922b7ead07b465f559002b7b11d53f40937e95067ca1c/ruff-0.14.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:344d97172576d75dc6afc0e9243376dbe1668559c72de1864439c4fc95f78185", size = 14054487, upload-time = "2025-11-21T14:26:05.058Z" }, - { url = "https://files.pythonhosted.org/packages/67/0b/8e4e0639e4cc12547f41cb771b0b44ec8225b6b6a93393176d75fe6f7d40/ruff-0.14.6-py3-none-win32.whl", hash = "sha256:00169c0c8b85396516fdd9ce3446c7ca20c2a8f90a77aa945ba6b8f2bfe99e85", size = 13013361, upload-time = "2025-11-21T14:26:08.152Z" }, - { url = "https://files.pythonhosted.org/packages/fb/02/82240553b77fd1341f80ebb3eaae43ba011c7a91b4224a9f317d8e6591af/ruff-0.14.6-py3-none-win_amd64.whl", hash = "sha256:390e6480c5e3659f8a4c8d6a0373027820419ac14fa0d2713bd8e6c3e125b8b9", size = 14432087, upload-time = "2025-11-21T14:26:10.891Z" }, - { url = "https://files.pythonhosted.org/packages/a5/1f/93f9b0fad9470e4c829a5bb678da4012f0c710d09331b860ee555216f4ea/ruff-0.14.6-py3-none-win_arm64.whl", hash = "sha256:d43c81fbeae52cfa8728d8766bbf46ee4298c888072105815b392da70ca836b2", size = 13520930, upload-time = "2025-11-21T14:26:13.951Z" }, + { url = "https://files.pythonhosted.org/packages/47/20/5369c3ce21588c708bcbe517a8fbe1a8dfdb5dfd5137e14790b1da71612c/ruff-0.15.5-py3-none-linux_armv6l.whl", hash = "sha256:4ae44c42281f42e3b06b988e442d344a5b9b72450ff3c892e30d11b29a96a57c", size = 10478185, upload-time = "2026-03-05T20:06:29.093Z" }, + { url = "https://files.pythonhosted.org/packages/44/ed/e81dd668547da281e5dce710cf0bc60193f8d3d43833e8241d006720e42b/ruff-0.15.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6edd3792d408ebcf61adabc01822da687579a1a023f297618ac27a5b51ef0080", size = 10859201, upload-time = "2026-03-05T20:06:32.632Z" }, + { url = "https://files.pythonhosted.org/packages/c4/8f/533075f00aaf19b07c5cd6aa6e5d89424b06b3b3f4583bfa9c640a079059/ruff-0.15.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:89f463f7c8205a9f8dea9d658d59eff49db05f88f89cc3047fb1a02d9f344010", size = 10184752, upload-time = "2026-03-05T20:06:40.312Z" }, + { url = "https://files.pythonhosted.org/packages/66/0e/ba49e2c3fa0395b3152bad634c7432f7edfc509c133b8f4529053ff024fb/ruff-0.15.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba786a8295c6574c1116704cf0b9e6563de3432ac888d8f83685654fe528fd65", size = 10534857, upload-time = "2026-03-05T20:06:19.581Z" }, + { url = "https://files.pythonhosted.org/packages/59/71/39234440f27a226475a0659561adb0d784b4d247dfe7f43ffc12dd02e288/ruff-0.15.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd4b801e57955fe9f02b31d20375ab3a5c4415f2e5105b79fb94cf2642c91440", size = 10309120, upload-time = "2026-03-05T20:06:00.435Z" }, + { url = 
"https://files.pythonhosted.org/packages/f5/87/4140aa86a93df032156982b726f4952aaec4a883bb98cb6ef73c347da253/ruff-0.15.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391f7c73388f3d8c11b794dbbc2959a5b5afe66642c142a6effa90b45f6f5204", size = 11047428, upload-time = "2026-03-05T20:05:51.867Z" }, + { url = "https://files.pythonhosted.org/packages/5a/f7/4953e7e3287676f78fbe85e3a0ca414c5ca81237b7575bdadc00229ac240/ruff-0.15.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc18f30302e379fe1e998548b0f5e9f4dff907f52f73ad6da419ea9c19d66c8", size = 11914251, upload-time = "2026-03-05T20:06:22.887Z" }, + { url = "https://files.pythonhosted.org/packages/77/46/0f7c865c10cf896ccf5a939c3e84e1cfaeed608ff5249584799a74d33835/ruff-0.15.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc6e7f90087e2d27f98dc34ed1b3ab7c8f0d273cc5431415454e22c0bd2a681", size = 11333801, upload-time = "2026-03-05T20:05:57.168Z" }, + { url = "https://files.pythonhosted.org/packages/d3/01/a10fe54b653061585e655f5286c2662ebddb68831ed3eaebfb0eb08c0a16/ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1cb7169f53c1ddb06e71a9aebd7e98fc0fea936b39afb36d8e86d36ecc2636a", size = 11206821, upload-time = "2026-03-05T20:06:03.441Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0d/2132ceaf20c5e8699aa83da2706ecb5c5dcdf78b453f77edca7fb70f8a93/ruff-0.15.5-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9b037924500a31ee17389b5c8c4d88874cc6ea8e42f12e9c61a3d754ff72f1ca", size = 11133326, upload-time = "2026-03-05T20:06:25.655Z" }, + { url = "https://files.pythonhosted.org/packages/72/cb/2e5259a7eb2a0f87c08c0fe5bf5825a1e4b90883a52685524596bfc93072/ruff-0.15.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:65bb414e5b4eadd95a8c1e4804f6772bbe8995889f203a01f77ddf2d790929dd", size = 10510820, upload-time = "2026-03-05T20:06:37.79Z" }, + { url = "https://files.pythonhosted.org/packages/ff/20/b67ce78f9e6c59ffbdb5b4503d0090e749b5f2d31b599b554698a80d861c/ruff-0.15.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d20aa469ae3b57033519c559e9bc9cd9e782842e39be05b50e852c7c981fa01d", size = 10302395, upload-time = "2026-03-05T20:05:54.504Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e5/719f1acccd31b720d477751558ed74e9c88134adcc377e5e886af89d3072/ruff-0.15.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:15388dd28c9161cdb8eda68993533acc870aa4e646a0a277aa166de9ad5a8752", size = 10754069, upload-time = "2026-03-05T20:06:06.422Z" }, + { url = "https://files.pythonhosted.org/packages/c3/9c/d1db14469e32d98f3ca27079dbd30b7b44dbb5317d06ab36718dee3baf03/ruff-0.15.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b30da330cbd03bed0c21420b6b953158f60c74c54c5f4c1dabbdf3a57bf355d2", size = 11304315, upload-time = "2026-03-05T20:06:10.867Z" }, + { url = "https://files.pythonhosted.org/packages/28/3a/950367aee7c69027f4f422059227b290ed780366b6aecee5de5039d50fa8/ruff-0.15.5-py3-none-win32.whl", hash = "sha256:732e5ee1f98ba5b3679029989a06ca39a950cced52143a0ea82a2102cb592b74", size = 10551676, upload-time = "2026-03-05T20:06:13.705Z" }, + { url = "https://files.pythonhosted.org/packages/b8/00/bf077a505b4e649bdd3c47ff8ec967735ce2544c8e4a43aba42ee9bf935d/ruff-0.15.5-py3-none-win_amd64.whl", hash = "sha256:821d41c5fa9e19117616c35eaa3f4b75046ec76c65e7ae20a333e9a8696bc7fe", size = 11678972, upload-time = "2026-03-05T20:06:45.379Z" }, + { url = 
"https://files.pythonhosted.org/packages/fe/4e/cd76eca6db6115604b7626668e891c9dd03330384082e33662fb0f113614/ruff-0.15.5-py3-none-win_arm64.whl", hash = "sha256:b498d1c60d2fe5c10c45ec3f698901065772730b411f164ae270bb6bfcc4740b", size = 10965572, upload-time = "2026-03-05T20:06:16.984Z" }, ] [[package]] @@ -6418,11 +6430,11 @@ wheels = [ [[package]] name = "types-cachetools" -version = "5.5.0.20240820" +version = "6.2.0.20251022" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c2/7e/ad6ba4a56b2a994e0f0a04a61a50466b60ee88a13d10a18c83ac14a66c61/types-cachetools-5.5.0.20240820.tar.gz", hash = "sha256:b888ab5c1a48116f7799cd5004b18474cd82b5463acb5ffb2db2fc9c7b053bc0", size = 4198, upload-time = "2024-08-20T02:30:07.525Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/a8/f9bcc7f1be63af43ef0170a773e2d88817bcc7c9d8769f2228c802826efe/types_cachetools-6.2.0.20251022.tar.gz", hash = "sha256:f1d3c736f0f741e89ec10f0e1b0138625023e21eb33603a930c149e0318c0cef", size = 9608, upload-time = "2025-10-22T03:03:58.16Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/27/4d/fd7cc050e2d236d5570c4d92531c0396573a1e14b31735870e849351c717/types_cachetools-5.5.0.20240820-py3-none-any.whl", hash = "sha256:efb2ed8bf27a4b9d3ed70d33849f536362603a90b8090a328acf0cd42fda82e2", size = 4149, upload-time = "2024-08-20T02:30:06.461Z" }, + { url = "https://files.pythonhosted.org/packages/98/2d/8d821ed80f6c2c5b427f650bf4dc25b80676ed63d03388e4b637d2557107/types_cachetools-6.2.0.20251022-py3-none-any.whl", hash = "sha256:698eb17b8f16b661b90624708b6915f33dbac2d185db499ed57e4997e7962cad", size = 9341, upload-time = "2025-10-22T03:03:57.036Z" }, ] [[package]] @@ -6457,32 +6469,32 @@ wheels = [ [[package]] name = "types-deprecated" -version = "1.2.15.20250304" +version = "1.3.1.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0e/67/eeefaaabb03b288aad85483d410452c8bbcbf8b2bd876b0e467ebd97415b/types_deprecated-1.2.15.20250304.tar.gz", hash = "sha256:c329030553029de5cc6cb30f269c11f4e00e598c4241290179f63cda7d33f719", size = 8015, upload-time = "2025-03-04T02:48:17.894Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/97/9924e496f88412788c432891cacd041e542425fe0bffff4143a7c1c89ac4/types_deprecated-1.3.1.20260130.tar.gz", hash = "sha256:726b05e5e66d42359b1d6631835b15de62702588c8a59b877aa4b1e138453450", size = 8455, upload-time = "2026-01-30T03:58:17.401Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/e3/c18aa72ab84e0bc127a3a94e93be1a6ac2cb281371d3a45376ab7cfdd31c/types_deprecated-1.2.15.20250304-py3-none-any.whl", hash = "sha256:86a65aa550ea8acf49f27e226b8953288cd851de887970fbbdf2239c116c3107", size = 8553, upload-time = "2025-03-04T02:48:16.666Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b2/6f920582af7efcd37165cd6321707f3ad5839dd24565a8a982f2bd9c6fd1/types_deprecated-1.3.1.20260130-py3-none-any.whl", hash = "sha256:593934d85c38ca321a9d301f00c42ffe13e4cf830b71b10579185ba0ce172d9a", size = 9077, upload-time = "2026-01-30T03:58:16.633Z" }, ] [[package]] name = "types-docutils" -version = "0.21.0.20250809" +version = "0.22.3.20260223" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/be/9b/f92917b004e0a30068e024e8925c7d9b10440687b96d91f26d8762f4b68c/types_docutils-0.21.0.20250809.tar.gz", hash = "sha256:cc2453c87dc729b5aae499597496e4f69b44aa5fccb27051ed8bb55b0bd5e31b", size = 54770, 
upload-time = "2025-08-09T03:15:42.752Z" } +sdist = { url = "https://files.pythonhosted.org/packages/80/33/92c0129283363e3b3ba270bf6a2b7d077d949d2f90afc4abaf6e73578563/types_docutils-0.22.3.20260223.tar.gz", hash = "sha256:e90e868da82df615ea2217cf36dff31f09660daa15fc0f956af53f89c1364501", size = 57230, upload-time = "2026-02-23T04:11:21.806Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/a9/46bc12e4c918c4109b67401bf87fd450babdffbebd5dbd7833f5096f42a5/types_docutils-0.21.0.20250809-py3-none-any.whl", hash = "sha256:af02c82327e8ded85f57dd85c8ebf93b6a0b643d85a44c32d471e3395604ea50", size = 89598, upload-time = "2025-08-09T03:15:41.503Z" }, + { url = "https://files.pythonhosted.org/packages/ba/c7/a4ae6a75d5b07d63089d5c04d450a0de4a5d48ffcb84b95659b22d3885fe/types_docutils-0.22.3.20260223-py3-none-any.whl", hash = "sha256:cc2d6b7560a28e351903db0989091474aa619ad287843a018324baee9c4d9a8f", size = 91969, upload-time = "2026-02-23T04:11:20.966Z" }, ] [[package]] name = "types-flask-cors" -version = "5.0.0.20250413" +version = "6.0.0.20250809" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/f3/dd2f0d274ecb77772d3ce83735f75ad14713461e8cf7e6d61a7c272037b1/types_flask_cors-5.0.0.20250413.tar.gz", hash = "sha256:b346d052f4ef3b606b73faf13e868e458f1efdbfedcbe1aba739eb2f54a6cf5f", size = 9921, upload-time = "2025-04-13T04:04:15.515Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/e0/e5dd841bf475765fb61cb04c1e70d2fd0675a0d4ddfacd50a333eafe7267/types_flask_cors-6.0.0.20250809.tar.gz", hash = "sha256:24380a2b82548634c0931d50b9aafab214eea9f85dcc04f15ab1518752a7e6aa", size = 9951, upload-time = "2025-08-09T03:16:37.454Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/34/7d64eb72d80bfd5b9e6dd31e7fe351a1c9a735f5c01e85b1d3b903a9d656/types_flask_cors-5.0.0.20250413-py3-none-any.whl", hash = "sha256:8183fdba764d45a5b40214468a1d5daa0e86c4ee6042d13f38cc428308f27a64", size = 9982, upload-time = "2025-04-13T04:04:14.27Z" }, + { url = "https://files.pythonhosted.org/packages/9f/5e/1e60c29eb5796233d4d627ca4979c4ae8da962fd0aae0cdb6e3e6a807bbc/types_flask_cors-6.0.0.20250809-py3-none-any.whl", hash = "sha256:f6d660dddab946779f4263cb561bffe275d86cb8747ce02e9fec8d340780131b", size = 9971, upload-time = "2025-08-09T03:16:36.593Z" }, ] [[package]] @@ -6543,14 +6555,14 @@ wheels = [ [[package]] name = "types-jsonschema" -version = "4.23.0.20250516" +version = "4.26.0.20260202" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "referencing" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/ec/27ea5bffdb306bf261f6677a98b6993d93893b2c2e30f7ecc1d2c99d32e7/types_jsonschema-4.23.0.20250516.tar.gz", hash = "sha256:9ace09d9d35c4390a7251ccd7d833b92ccc189d24d1b347f26212afce361117e", size = 14911, upload-time = "2025-05-16T03:09:33.728Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/48/73ae8b388e19fc4a2a8060d0876325ec7310cfd09b53a2185186fd35959f/types_jsonschema-4.23.0.20250516-py3-none-any.whl", hash = "sha256:e7d0dd7db7e59e63c26e3230e26ffc64c4704cc5170dc21270b366a35ead1618", size = 15027, upload-time = 
"2025-05-16T03:09:32.499Z" }, + { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, ] [[package]] @@ -6564,11 +6576,11 @@ wheels = [ [[package]] name = "types-oauthlib" -version = "3.2.0.20250516" +version = "3.3.0.20250822" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/2c/dba2c193ccff2d1e2835589d4075b230d5627b9db363e9c8de153261d6ec/types_oauthlib-3.2.0.20250516.tar.gz", hash = "sha256:56bf2cffdb8443ae718d4e83008e3fbd5f861230b4774e6d7799527758119d9a", size = 24683, upload-time = "2025-05-16T03:07:42.484Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/6e/d08033f562053c459322333c46baa8cf8d2d8c18f30d46dd898c8fd8df77/types_oauthlib-3.3.0.20250822.tar.gz", hash = "sha256:2cd41587dd80c199e4230e3f086777e9ae525e89579c64afe5e0039ab09be9de", size = 25700, upload-time = "2025-08-22T03:02:41.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/54/cdd62283338616fd2448f534b29110d79a42aaabffaf5f45e7aed365a366/types_oauthlib-3.2.0.20250516-py3-none-any.whl", hash = "sha256:5799235528bc9bd262827149a1633ff55ae6e5a5f5f151f4dae74359783a31b3", size = 45671, upload-time = "2025-05-16T03:07:41.268Z" }, + { url = "https://files.pythonhosted.org/packages/18/4b/00593b8b5d055550e1fcb9af2c42fa11b0a90bf16a94759a77bc1c3c0c72/types_oauthlib-3.3.0.20250822-py3-none-any.whl", hash = "sha256:b7f4c9b9eed0e020f454e0af800b10e93dd2efd196da65744b76910cce7e70d6", size = 48800, upload-time = "2025-08-22T03:02:40.427Z" }, ] [[package]] @@ -6609,11 +6621,11 @@ wheels = [ [[package]] name = "types-protobuf" -version = "5.29.1.20250403" +version = "6.32.1.20260221" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/78/6d/62a2e73b966c77609560800004dd49a926920dd4976a9fdd86cf998e7048/types_protobuf-5.29.1.20250403.tar.gz", hash = "sha256:7ff44f15022119c9d7558ce16e78b2d485bf7040b4fadced4dd069bb5faf77a2", size = 59413, upload-time = "2025-04-02T10:07:17.138Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5f/e2/9aa4a3b2469508bd7b4e2ae11cbedaf419222a09a1b94daffcd5efca4023/types_protobuf-6.32.1.20260221.tar.gz", hash = "sha256:6d5fb060a616bfb076cbb61b4b3c3969f5fc8bec5810f9a2f7e648ee5cbcbf6e", size = 64408, upload-time = "2026-02-21T03:55:13.916Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/e3/b74dcc2797b21b39d5a4f08a8b08e20369b4ca250d718df7af41a60dd9f0/types_protobuf-5.29.1.20250403-py3-none-any.whl", hash = "sha256:c71de04106a2d54e5b2173d0a422058fae0ef2d058d70cf369fb797bf61ffa59", size = 73874, upload-time = "2025-04-02T10:07:15.755Z" }, + { url = "https://files.pythonhosted.org/packages/2e/e8/1fd38926f9cf031188fbc5a96694203ea6f24b0e34bd64a225ec6f6291ba/types_protobuf-6.32.1.20260221-py3-none-any.whl", hash = "sha256:da7cdd947975964a93c30bfbcc2c6841ee646b318d3816b033adc2c4eb6448e4", size = 77956, upload-time = "2026-02-21T03:55:12.894Z" }, ] [[package]] @@ -6686,22 +6698,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/4f/b88274658cf489e35175be8571c970e9a1219713bafd8fc9e166d7351ecb/types_python_http_client-3.3.7.20250708-py3-none-any.whl", hash = "sha256:e2fc253859decab36713d82fc7f205868c3ddeaee79dbb55956ad9ca77abe12b", size = 8890, upload-time = "2025-07-08T03:14:35.506Z" }, ] 
-[[package]] -name = "types-pytz" -version = "2025.2.0.20251108" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" }, -] - [[package]] name = "types-pywin32" -version = "310.0.0.20250516" +version = "311.0.0.20251008" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/bc/c7be2934a37cc8c645c945ca88450b541e482c4df3ac51e5556377d34811/types_pywin32-310.0.0.20250516.tar.gz", hash = "sha256:91e5bfc033f65c9efb443722eff8101e31d690dd9a540fa77525590d3da9cc9d", size = 328459, upload-time = "2025-05-16T03:07:57.411Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/05/cd94300066241a7abb52238f0dd8d7f4fe1877cf2c72bd1860856604d962/types_pywin32-311.0.0.20251008.tar.gz", hash = "sha256:d6d4faf8e0d7fdc0e0a1ff297b80be07d6d18510f102d793bf54e9e3e86f6d06", size = 329561, upload-time = "2025-10-08T02:51:39.436Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/72/469e4cc32399dbe6c843e38fdb6d04fee755e984e137c0da502f74d3ac59/types_pywin32-310.0.0.20250516-py3-none-any.whl", hash = "sha256:f9ef83a1ec3e5aae2b0e24c5f55ab41272b5dfeaabb9a0451d33684c9545e41a", size = 390411, upload-time = "2025-05-16T03:07:56.282Z" }, + { url = "https://files.pythonhosted.org/packages/af/08/00a38e6b71585e6741d5b3b4cc9dd165cf549b6f1ed78815c6585f8b1b58/types_pywin32-311.0.0.20251008-py3-none-any.whl", hash = "sha256:775e1046e0bad6d29ca47501301cce67002f6661b9cebbeca93f9c388c53fab4", size = 392942, upload-time = "2025-10-08T02:51:38.327Z" }, ] [[package]] @@ -6728,11 +6731,11 @@ wheels = [ [[package]] name = "types-regex" -version = "2024.11.6.20250403" +version = "2026.2.28.20260301" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/75/012b90c8557d3abb3b58a9073a94d211c8f75c9b2e26bf0d8af7ecf7bc78/types_regex-2024.11.6.20250403.tar.gz", hash = "sha256:3fdf2a70bbf830de4b3a28e9649a52d43dabb57cdb18fbfe2252eefb53666665", size = 12394, upload-time = "2025-04-03T02:54:35.379Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/ed/106958cb686316113b748ed4209fa363fd92b15759d5409c3930fed36606/types_regex-2026.2.28.20260301.tar.gz", hash = "sha256:644c231db3f368908320170c14905731a7ae5fabdac0f60f5d6d12ecdd3bc8dd", size = 13157, upload-time = "2026-03-01T04:11:13.559Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/49/67200c4708f557be6aa4ecdb1fa212d67a10558c5240251efdc799cca22f/types_regex-2024.11.6.20250403-py3-none-any.whl", hash = "sha256:e22c0f67d73f4b4af6086a340f387b6f7d03bed8a0bb306224b75c51a29b0001", size = 10396, upload-time = "2025-04-03T02:54:34.555Z" }, + { url = "https://files.pythonhosted.org/packages/c7/bb/9bc26fcf5155bd25efeca35f8ba6bffb8b3c9da2baac8bf40067606418f3/types_regex-2026.2.28.20260301-py3-none-any.whl", hash = "sha256:7da7a1fe67528238176a5844fd435ca90617cf605341308686afbc579fdea5c0", size = 11130, upload-time = "2026-03-01T04:11:11.454Z" 
}, ] [[package]] diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 1ec95deb09..1ae115f85c 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -38,7 +38,6 @@ BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "UPSTASH_VECTOR_URL", "USING_UGC_INDEX", "WEAVIATE_BATCH_SIZE", - "WEAVIATE_GRPC_ENABLED", } BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { @@ -86,7 +85,6 @@ BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { "VIKINGDB_CONNECTION_TIMEOUT", "VIKINGDB_SOCKET_TIMEOUT", "WEAVIATE_BATCH_SIZE", - "WEAVIATE_GRPC_ENABLED", } API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys()) diff --git a/web/.env.example b/web/.env.example index df4e725c51..ed06ebe2c9 100644 --- a/web/.env.example +++ b/web/.env.example @@ -12,6 +12,11 @@ NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api # console or api domain. # example: http://udify.app/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api +# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly. +HONO_PROXY_HOST=127.0.0.1 +HONO_PROXY_PORT=5001 +HONO_CONSOLE_API_PROXY_TARGET= +HONO_PUBLIC_API_PROXY_TARGET= # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= diff --git a/web/__tests__/component-coverage-filters.test.ts b/web/__tests__/component-coverage-filters.test.ts new file mode 100644 index 0000000000..cacc1e2142 --- /dev/null +++ b/web/__tests__/component-coverage-filters.test.ts @@ -0,0 +1,115 @@ +import fs from 'node:fs' +import os from 'node:os' +import path from 'node:path' +import { afterEach, describe, expect, it } from 'vitest' +import { + collectComponentCoverageExcludedFiles, + COMPONENT_COVERAGE_EXCLUDE_LABEL, + getComponentCoverageExclusionReasons, +} from '../scripts/component-coverage-filters.mjs' + +describe('component coverage filters', () => { + describe('getComponentCoverageExclusionReasons', () => { + it('should exclude type-only files by basename', () => { + expect( + getComponentCoverageExclusionReasons( + 'web/app/components/share/text-generation/types.ts', + 'export type ShareMode = "run-once" | "run-batch"', + ), + ).toContain('type-only') + }) + + it('should exclude pure barrel files', () => { + expect( + getComponentCoverageExclusionReasons( + 'web/app/components/base/amplitude/index.ts', + [ + 'export { default } from "./AmplitudeProvider"', + 'export { resetUser, trackEvent } from "./utils"', + ].join('\n'), + ), + ).toContain('pure-barrel') + }) + + it('should exclude generated files from marker comments', () => { + expect( + getComponentCoverageExclusionReasons( + 'web/app/components/base/icons/src/vender/workflow/Answer.tsx', + [ + '// GENERATE BY script', + '// DON NOT EDIT IT MANUALLY', + 'export default function Icon() {', + ' return null', + '}', + ].join('\n'), + ), + ).toContain('generated') + }) + + it('should exclude pure static files with exported constants only', () => { + expect( + getComponentCoverageExclusionReasons( + 'web/app/components/workflow/note-node/constants.ts', + [ + 'import { NoteTheme } from "./types"', + 'export const CUSTOM_NOTE_NODE = "custom-note"', + 'export const THEME_MAP = {', + ' [NoteTheme.blue]: { title: "bg-blue-100" },', + '}', + ].join('\n'), + ), + ).toContain('pure-static') + }) + + it('should keep runtime logic files tracked', () => { + expect( + getComponentCoverageExclusionReasons( + 'web/app/components/workflow/nodes/trigger-schedule/default.ts', + [ + 'const validate = 
(value: string) => value.trim()',
+          'export const nodeDefault = {',
+          '  value: validate("x"),',
+          '}',
+        ].join('\n'),
+      ),
+    ).toEqual([])
+    })
+  })
+
+  describe('collectComponentCoverageExcludedFiles', () => {
+    const tempDirs: string[] = []
+
+    afterEach(() => {
+      for (const dir of tempDirs)
+        fs.rmSync(dir, { recursive: true, force: true })
+      tempDirs.length = 0
+    })
+
+    it('should collect excluded files for coverage config and keep runtime files out', () => {
+      const rootDir = fs.mkdtempSync(path.join(os.tmpdir(), 'component-coverage-filters-'))
+      tempDirs.push(rootDir)
+
+      fs.mkdirSync(path.join(rootDir, 'barrel'), { recursive: true })
+      fs.mkdirSync(path.join(rootDir, 'icons'), { recursive: true })
+      fs.mkdirSync(path.join(rootDir, 'static'), { recursive: true })
+      fs.mkdirSync(path.join(rootDir, 'runtime'), { recursive: true })
+
+      fs.writeFileSync(path.join(rootDir, 'barrel', 'index.ts'), 'export { default } from "./Button"\n')
+      fs.writeFileSync(path.join(rootDir, 'icons', 'generated-icon.tsx'), '// @generated\nexport default function Icon() { return null }\n')
+      fs.writeFileSync(path.join(rootDir, 'static', 'constants.ts'), 'export const COLORS = { primary: "#fff" }\n')
+      fs.writeFileSync(path.join(rootDir, 'runtime', 'config.ts'), 'export const config = makeConfig()\n')
+      fs.writeFileSync(path.join(rootDir, 'runtime', 'types.ts'), 'export type Config = { value: string }\n')
+
+      expect(collectComponentCoverageExcludedFiles(rootDir, { pathPrefix: 'app/components' })).toEqual([
+        'app/components/barrel/index.ts',
+        'app/components/icons/generated-icon.tsx',
+        'app/components/runtime/types.ts',
+        'app/components/static/constants.ts',
+      ])
+    })
+  })
+
+  it('should describe the excluded coverage categories', () => {
+    expect(COMPONENT_COVERAGE_EXCLUDE_LABEL).toBe('type-only files, pure barrel files, generated files, pure static files')
+  })
+})
diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx
new file mode 100644
index 0000000000..3292474bec
--- /dev/null
+++ b/web/__tests__/share/text-generation-index-flow.test.tsx
@@ -0,0 +1,235 @@
+import type { AccessMode } from '@/models/access-control'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import * as React from 'react'
+import TextGeneration from '@/app/components/share/text-generation'
+
+const useSearchParamsMock = vi.fn(() => new URLSearchParams())
+
+vi.mock('next/navigation', () => ({
+  useSearchParams: () => useSearchParamsMock(),
+}))
+
+vi.mock('@/hooks/use-breakpoints', () => ({
+  default: vi.fn(() => 'pc'),
+  MediaType: { pc: 'pc', pad: 'pad', mobile: 'mobile' },
+}))
+
+vi.mock('@/hooks/use-app-favicon', () => ({
+  useAppFavicon: vi.fn(),
+}))
+
+vi.mock('@/hooks/use-document-title', () => ({
+  default: vi.fn(),
+}))
+
+vi.mock('@/i18n-config/client', () => ({
+  changeLanguage: vi.fn(() => Promise.resolve()),
+}))
+
+vi.mock('@/app/components/share/text-generation/run-once', () => ({
+  default: ({
+    inputs,
+    onInputsChange,
+    onSend,
+    runControl,
+  }: {
+    inputs: Record<string, unknown>
+    onInputsChange: (inputs: Record<string, unknown>) => void
+    onSend: () => void
+    runControl?: { isStopping: boolean } | null
+  }) => (
+    <div data-testid="run-once-mock">
+      <span data-testid="run-once-input-name">{String(inputs.name ?? '')}</span>
+      <button onClick={() => onInputsChange({ ...inputs, name: 'Gamma' })}>change-inputs</button>
+      <button onClick={onSend}>run-once</button>
+      <span>{runControl ? 'stop-ready' : 'idle'}</span>
+    </div>
+  ),
+}))
+
+vi.mock('@/app/components/share/text-generation/run-batch', () => ({
+  default: ({ onSend }: { onSend: (data: string[][]) => void }) => (
+    <button onClick={() => onSend([['Name'], ['Alpha'], ['Beta']])}>run-batch</button>
+  ),
+}))
+
+vi.mock('@/app/components/app/text-generate/saved-items', () => ({
+  default: ({ list }: { list: { id: string }[] }) => <div data-testid="saved-items-mock">{list.length}</div>,
+}))
+
+vi.mock('@/app/components/share/text-generation/menu-dropdown', () => ({
+  default: () => <div data-testid="menu-dropdown-mock" />,
+}))
+
+vi.mock('@/app/components/share/text-generation/result', () => {
+  const MockResult = ({
+    isCallBatchAPI,
+    onRunControlChange,
+    onRunStart,
+    taskId,
+  }: {
+    isCallBatchAPI: boolean
+    onRunControlChange?: (control: { onStop: () => void, isStopping: boolean } | null) => void
+    onRunStart: () => void
+    taskId?: number
+  }) => {
+    const runControlRef = React.useRef(false)
+
+    React.useEffect(() => {
+      onRunStart()
+    }, [onRunStart])
+
+    React.useEffect(() => {
+      if (!isCallBatchAPI && !runControlRef.current) {
+        runControlRef.current = true
+        onRunControlChange?.({ onStop: vi.fn(), isStopping: false })
+      }
+    }, [isCallBatchAPI, onRunControlChange])
+
+    return <div data-testid={isCallBatchAPI ? `result-task-${taskId}` : 'result-single'} />
+  }
+
+  return {
+    default: MockResult,
+  }
+})
+
+const fetchSavedMessageMock = vi.fn()
+
+vi.mock('@/service/share', async () => {
+  const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
+  return {
+    ...actual,
+    fetchSavedMessage: (...args: Parameters<typeof actual.fetchSavedMessage>) => fetchSavedMessageMock(...args),
+    removeMessage: vi.fn(),
+    saveMessage: vi.fn(),
+  }
+})
+
+const mockSystemFeatures = {
+  branding: {
+    enabled: false,
+    workspace_logo: null,
+  },
+}
+
+const mockWebAppState = {
+  appInfo: {
+    app_id: 'app-123',
+    site: {
+      title: 'Text Generation',
+      description: 'Share description',
+      default_language: 'en-US',
+      icon_type: 'emoji',
+      icon: 'robot',
+      icon_background: '#fff',
+      icon_url: '',
+    },
+    custom_config: {
+      remove_webapp_brand: false,
+      replace_webapp_logo: '',
+    },
+  },
+  appParams: {
+    user_input_form: [
+      {
+        'text-input': {
+          label: 'Name',
+          variable: 'name',
+          required: true,
+          max_length: 48,
+          default: '',
+          hide: false,
+        },
+      },
+    ],
+    more_like_this: {
+      enabled: true,
+    },
+    file_upload: {
+      enabled: false,
+      number_limits: 2,
+      detail: 'low',
+      allowed_upload_methods: ['local_file'],
+    },
+    text_to_speech: {
+      enabled: true,
+    },
+    system_parameters: {
+      image_file_size_limit: 10,
+    },
+  },
+  webAppAccessMode: 'public' as AccessMode,
+}
+
+vi.mock('@/context/global-public-context', () => ({
+  useGlobalPublicStore: (selector: (state: { systemFeatures: typeof mockSystemFeatures }) => unknown) =>
+    selector({ systemFeatures: mockSystemFeatures }),
+}))
+
+vi.mock('@/context/web-app-context', () => ({
+  useWebAppStore: (selector: (state: typeof mockWebAppState) => unknown) => selector(mockWebAppState),
+}))
+
+describe('TextGeneration', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    useSearchParamsMock.mockReturnValue(new URLSearchParams())
+    fetchSavedMessageMock.mockResolvedValue({
+      data: [{ id: 'saved-1' }, { id: 'saved-2' }],
+    })
+  })
+
+  it('should switch between create, batch, and saved tabs after app state loads', async () => {
+    render(<TextGeneration />)
+
+    await waitFor(() => {
+      expect(screen.getByTestId('run-once-mock')).toBeInTheDocument()
+    })
+    expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('')
+
+    fireEvent.click(screen.getByRole('button', { name: 'change-inputs' }))
+    await waitFor(() => {
+      expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('Gamma')
+    })
+
+    fireEvent.click(screen.getByTestId('tab-header-item-batch'))
+    expect(screen.getByRole('button', { name: 'run-batch' })).toBeInTheDocument()
+
+    fireEvent.click(screen.getByTestId('tab-header-item-saved'))
+    expect(screen.getByTestId('saved-items-mock')).toHaveTextContent('2')
+
+    fireEvent.click(screen.getByTestId('tab-header-item-create'))
+    expect(screen.getByTestId('run-once-mock')).toBeInTheDocument()
+  })
+
+  it('should wire single-run stop control and clear it when batch execution starts', async () => {
+    render(<TextGeneration />)
+
+    await waitFor(() => {
+      expect(screen.getByTestId('run-once-mock')).toBeInTheDocument()
+    })
+
+    fireEvent.click(screen.getByRole('button', { name: 'run-once' }))
+    await waitFor(() => {
+      expect(screen.getByText('stop-ready')).toBeInTheDocument()
+    })
+    expect(screen.getByTestId('result-single')).toBeInTheDocument()
+
+    fireEvent.click(screen.getByTestId('tab-header-item-batch'))
+    fireEvent.click(screen.getByRole('button', { name: 'run-batch' }))
+    await waitFor(() => {
+      expect(screen.getByText('idle')).toBeInTheDocument()
+    })
+    expect(screen.getByTestId('result-task-1')).toBeInTheDocument()
+
expect(screen.getByTestId('result-task-2')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index f348a7718d..51a9e87a97 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -179,7 +179,7 @@ const Tools = () => {
handleSaveExternalDataToolModal({ ...item, enabled }, index)} /> diff --git a/web/app/components/app/in-site-message/index.spec.tsx b/web/app/components/app/in-site-message/index.spec.tsx index 69f036da17..530084074d 100644 --- a/web/app/components/app/in-site-message/index.spec.tsx +++ b/web/app/components/app/in-site-message/index.spec.tsx @@ -1,7 +1,13 @@ +import type { ComponentProps } from 'react' import type { InSiteMessageActionItem } from './index' import { fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import InSiteMessage from './index' +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) + describe('InSiteMessage', () => { const originalLocation = window.location @@ -18,9 +24,10 @@ describe('InSiteMessage', () => { vi.unstubAllGlobals() }) - const renderComponent = (actions: InSiteMessageActionItem[], props?: Partial>) => { + const renderComponent = (actions: InSiteMessageActionItem[], props?: Partial>) => { return render( { describe('Rendering', () => { it('should render title, subtitle, markdown content, and action buttons', () => { const actions: InSiteMessageActionItem[] = [ - { action: 'close', text: 'Close', type: 'default' }, - { action: 'link', text: 'Learn more', type: 'primary', data: 'https://example.com' }, + { action: 'close', action_name: 'dismiss', text: 'Close', type: 'default' }, + { action: 'link', action_name: 'learn_more', text: 'Learn more', type: 'primary', data: 'https://example.com' }, ] renderComponent(actions, { className: 'custom-message' }) @@ -56,7 +63,7 @@ describe('InSiteMessage', () => { }) it('should fallback to default header background when headerBgUrl is empty string', () => { - const actions: InSiteMessageActionItem[] = [{ action: 'close', text: 'Close', type: 'default' }] + const actions: InSiteMessageActionItem[] = [{ action: 'close', action_name: 'dismiss', text: 'Close', type: 'default' }] const { container } = renderComponent(actions, { headerBgUrl: '' }) const header = container.querySelector('div[style]') @@ -68,7 +75,7 @@ describe('InSiteMessage', () => { describe('Actions', () => { it('should call onAction and hide component when close action is clicked', () => { const onAction = vi.fn() - const closeAction: InSiteMessageActionItem = { action: 'close', text: 'Close', type: 'default' } + const closeAction: InSiteMessageActionItem = { action: 'close', action_name: 'dismiss', text: 'Close', type: 'default' } renderComponent([closeAction], { onAction }) fireEvent.click(screen.getByRole('button', { name: 'Close' })) @@ -80,6 +87,7 @@ describe('InSiteMessage', () => { it('should open a new tab when link action data is a string', () => { const linkAction: InSiteMessageActionItem = { action: 'link', + action_name: 'confirm', text: 'Open link', type: 'primary', data: 'https://example.com', @@ -103,6 +111,7 @@ describe('InSiteMessage', () => { const linkAction: InSiteMessageActionItem = { action: 'link', + action_name: 'confirm', text: 'Open self', type: 'primary', data: { href: 'https://example.com/self', target: '_self' }, @@ -118,6 +127,7 @@ describe('InSiteMessage', () => { it('should not trigger navigation when link data is invalid', () => { const linkAction: InSiteMessageActionItem = { action: 'link', + action_name: 'confirm', text: 'Broken link', type: 'primary', data: { rel: 'noopener' }, diff --git a/web/app/components/app/in-site-message/index.tsx b/web/app/components/app/in-site-message/index.tsx index 
9225eb8a15..0276257860 100644 --- a/web/app/components/app/in-site-message/index.tsx +++ b/web/app/components/app/in-site-message/index.tsx @@ -1,6 +1,7 @@ 'use client' -import { useMemo, useState } from 'react' +import { useEffect, useMemo, useState } from 'react' +import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import { MarkdownWithDirective } from '@/app/components/base/markdown-with-directive' import { cn } from '@/utils/classnames' @@ -10,12 +11,14 @@ type InSiteMessageButtonType = 'primary' | 'default' export type InSiteMessageActionItem = { action: InSiteMessageAction + action_name: string // for tracing and analytics data?: unknown text: string type: InSiteMessageButtonType } type InSiteMessageProps = { + notificationId: string actions: InSiteMessageActionItem[] className?: string headerBgUrl?: string @@ -52,6 +55,7 @@ function normalizeLinkData(data: unknown): { href: string, rel?: string, target? const DEFAULT_HEADER_BG_URL = '/in-site-message/header-bg.svg' function InSiteMessage({ + notificationId, actions, className, headerBgUrl = DEFAULT_HEADER_BG_URL, @@ -70,7 +74,17 @@ function InSiteMessage({ } }, [headerBgUrl]) + useEffect(() => { + trackEvent('in_site_message_show', { + notification_id: notificationId, + }) + }, [notificationId]) + const handleAction = (item: InSiteMessageActionItem) => { + trackEvent('in_site_message_action', { + notification_id: notificationId, + action: item.action_name, + }) onAction?.(item) if (item.action === 'close') { diff --git a/web/app/components/app/in-site-message/notification.spec.tsx b/web/app/components/app/in-site-message/notification.spec.tsx index 84fe3aebc7..0d86d8a91c 100644 --- a/web/app/components/app/in-site-message/notification.spec.tsx +++ b/web/app/components/app/in-site-message/notification.spec.tsx @@ -15,11 +15,16 @@ const { mockNotificationDismiss: vi.fn(), })) -vi.mock('@/config', () => ({ - get IS_CLOUD_EDITION() { - return mockConfig.isCloudEdition - }, -})) +vi.mock(import('@/config'), async (importOriginal) => { + const actual = await importOriginal() + + return { + ...actual, + get IS_CLOUD_EDITION() { + return mockConfig.isCloudEdition + }, + } +}) vi.mock('@/service/client', () => ({ consoleQuery: { diff --git a/web/app/components/app/in-site-message/notification.tsx b/web/app/components/app/in-site-message/notification.tsx index de256a4663..cebf6ffd91 100644 --- a/web/app/components/app/in-site-message/notification.tsx +++ b/web/app/components/app/in-site-message/notification.tsx @@ -75,6 +75,7 @@ function InSiteMessageNotification() { const fallbackActions: InSiteMessageActionItem[] = [ { type: 'default', + action_name: 'dismiss', text: t('operation.close', { ns: 'common' }), action: 'close', }, @@ -96,6 +97,7 @@ function InSiteMessageNotification() { return ( { }) }) + afterEach(() => { + vi.useRealTimers() + }) + it('should render the modal and expose the expanded settings section', async () => { renderSettingsModal() expect(screen.getByText('appOverview.overview.appInfo.settings.title')).toBeInTheDocument() @@ -212,4 +216,54 @@ describe('SettingsModal', () => { })) expect(mockOnClose).toHaveBeenCalled() }) + + it('should clear the delayed hide-more timer when the modal unmounts after closing', () => { + vi.useFakeTimers() + const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout') + const { unmount } = renderSettingsModal() + + fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry')) + 
fireEvent.click(screen.getByText('common.operation.cancel'))
+    unmount()
+
+    expect(clearTimeoutSpy).toHaveBeenCalled()
+    vi.runAllTimers()
+  })
+
+  it('should replace the pending hide-more timer and clear the ref after the timeout completes', async () => {
+    const hideCallbacks: Array<() => void> = []
+    const originalSetTimeout = globalThis.setTimeout
+    const setTimeoutSpy = vi.spyOn(globalThis, 'setTimeout').mockImplementation(((
+      callback: TimerHandler,
+      delay?: number,
+      ...args: unknown[]
+    ) => {
+      if (delay === 200) {
+        hideCallbacks.push(() => {
+          if (typeof callback === 'function')
+            callback(...args)
+        })
+        return hideCallbacks.length as unknown as ReturnType<typeof setTimeout>
+      }
+
+      return originalSetTimeout(callback, delay, ...args)
+    }) as unknown as typeof setTimeout)
+    const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout')
+    renderSettingsModal()
+
+    act(() => {
+      fireEvent.click(screen.getByText('common.operation.cancel'))
+      fireEvent.click(screen.getByText('common.operation.cancel'))
+    })
+
+    expect(clearTimeoutSpy).toHaveBeenCalled()
+    expect(hideCallbacks.length).toBeGreaterThanOrEqual(2)
+
+    act(() => {
+      hideCallbacks.at(-1)?.()
+    })
+
+    setTimeoutSpy.mockRestore()
+    clearTimeoutSpy.mockRestore()
+  })
 })
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx
index 92bfdc5d31..f7c9e309ab 100644
--- a/web/app/components/app/overview/settings/index.tsx
+++ b/web/app/components/app/overview/settings/index.tsx
@@ -6,7 +6,7 @@ import type { AppIconType, AppSSO, Language } from '@/types/app'
 import { RiArrowRightSLine, RiCloseLine } from '@remixicon/react'
 import Link from 'next/link'
 import * as React from 'react'
-import { useCallback, useEffect, useState } from 'react'
+import { useCallback, useEffect, useRef, useState } from 'react'
 import { Trans, useTranslation } from 'react-i18next'
 import ActionButton from '@/app/components/base/action-button'
 import AppIcon from '@/app/components/base/app-icon'
@@ -99,6 +99,7 @@ const SettingsModal: FC = ({
   const [language, setLanguage] = useState(default_language)
   const [saveLoading, setSaveLoading] = useState(false)
   const { t } = useTranslation()
+  const hideMoreTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
   const [showAppIconPicker, setShowAppIconPicker] = useState(false)
   const [appIcon, setAppIcon] = useState(
@@ -137,10 +138,22 @@
       : { type: 'emoji', icon, background: icon_background! })
   }, [appInfo, chat_color_theme, chat_color_theme_inverted, copyright, custom_disclaimer, default_language, description, icon, icon_background, icon_type, icon_url, privacy_policy, show_workflow_steps, title, use_icon_as_answer_icon])
 
+  useEffect(() => {
+    return () => {
+      if (hideMoreTimerRef.current) {
+        clearTimeout(hideMoreTimerRef.current)
+        hideMoreTimerRef.current = null
+      }
+    }
+  }, [])
+
   const onHide = () => {
     onClose()
-    setTimeout(() => {
+    if (hideMoreTimerRef.current)
+      clearTimeout(hideMoreTimerRef.current)
+    hideMoreTimerRef.current = setTimeout(() => {
       setIsShowMore(false)
+      hideMoreTimerRef.current = null
     }, 200)
   }
@@ -231,12 +244,12 @@ const SettingsModal: FC = ({
       {/* header */}
-
{t(`${prefixSettings}.title`, { ns: 'appOverview' })}
+
{t(`${prefixSettings}.title`, { ns: 'appOverview' })}
-
+
{t(`${prefixSettings}.modalTip`, { ns: 'appOverview' })}
@@ -245,7 +258,7 @@ const SettingsModal: FC = ({ {/* name & icon */}
-
{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}
+
{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}
= ({
{/* description */}
-
{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}
+
{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}
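
The settings-modal hunk above swaps a fire-and-forget `setTimeout` for a ref-tracked timer that is cleared both when a new hide is scheduled and when the component unmounts. A minimal standalone sketch of that pattern, assuming React; the `useDelayedReset` hook name is illustrative, not part of the diff:

```ts
import { useEffect, useRef } from 'react'

// Sketch of the unmount-safe delayed-callback pattern: keep the timer id in a
// ref, cancel any pending timer before scheduling a new one, and cancel again
// in the effect cleanup so the callback can never fire after unmount.
export function useDelayedReset(onReset: () => void, delayMs = 200) {
  const timerRef = useRef<ReturnType<typeof setTimeout> | null>(null)

  useEffect(() => {
    // Cleanup runs on unmount; without it the timeout could still invoke a
    // state setter on an unmounted component.
    return () => {
      if (timerRef.current) {
        clearTimeout(timerRef.current)
        timerRef.current = null
      }
    }
  }, [])

  return () => {
    if (timerRef.current)
      clearTimeout(timerRef.current) // replace, rather than stack, pending timers
    timerRef.current = setTimeout(() => {
      onReset()
      timerRef.current = null // the ref again reflects "nothing pending"
    }, delayMs)
  }
}
```

This is exactly why the spec above can assert both that `clearTimeout` is called on unmount and that a second cancel click replaces, rather than duplicates, the pending 200ms timer.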
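The component-coverage tests earlier in this diff pin down four exclusion categories (type-only, pure-barrel, generated, pure-static). The shipped logic lives in `web/scripts/component-coverage-filters.mjs`; the classifier below is an illustrative re-implementation whose heuristics and marker strings are assumptions, shown only to make the categories concrete against the test fixtures:

```ts
// Illustrative sketch of the four coverage-exclusion categories the tests
// exercise; the real implementation is richer than these heuristics.
export function exclusionReasons(filePath: string, source: string): string[] {
  const reasons: string[] = []
  const lines = source
    .split('\n')
    .map(line => line.trim())
    .filter(line => line.length > 0 && !line.startsWith('//'))

  // type-only: conventional basenames like types.ts carry no runtime code
  if (/(^|\/)types\.tsx?$/.test(filePath))
    reasons.push('type-only')

  // pure-barrel: every statement is a re-export from another module
  if (lines.length > 0 && lines.every(line => /^export\s*\{[^}]*\}\s*from\s/.test(line)))
    reasons.push('pure-barrel')

  // generated: a code generator left a marker comment in the file
  if (/@generated|GENERATE BY script/.test(source))
    reasons.push('generated')

  // pure-static: only imports and exported constants, with no call expression
  // anywhere in the body (so `makeConfig()` keeps a file tracked)
  const body = lines.filter(line => !line.startsWith('import '))
  const onlyConstDeclarations = body.length > 0 && body.every(line =>
    line.startsWith('export const ')
    || !/^(export|const|let|var|function|class)\b/.test(line), // continuation lines of object/array literals
  )
  if (onlyConstDeclarations && !/[\w$]\(/.test(body.join('\n')))
    reasons.push('pure-static')

  return reasons
}
```

Under these assumed heuristics, the fixtures above classify the same way the tests expect: `types.ts` is type-only, the amplitude `index.ts` is pure-barrel, the marker-commented icon is generated, `constants.ts` is pure-static, and `default.ts` (which calls `validate("x")`) reports no reasons and stays tracked.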