From 90d72f5ddf8fc962b6dc9852e3c9238eeea3e0d4 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 27 Aug 2025 17:46:46 +0800 Subject: [PATCH] merge new graph engine --- api/commands.py | 4 +- .../datasets/rag_pipeline/datasource_auth.py | 2 +- .../rag_pipeline_draft_variable.py | 5 +- .../rag_pipeline/rag_pipeline_workflow.py | 64 +-- api/controllers/console/datasets/wraps.py | 4 + api/controllers/console/spec.py | 2 +- .../service_api/dataset/document.py | 3 + api/core/agent/base_agent_runner.py | 4 +- .../app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/chat/app_runner.py | 4 +- .../common/workflow_response_converter.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 63 ++- api/core/plugin/impl/datasource.py | 2 +- .../index_processor/index_processor_base.py | 2 +- api/core/schemas/__init__.py | 2 +- api/core/schemas/registry.py | 46 +- api/core/schemas/resolver.py | 191 ++++--- api/core/schemas/schema_manager.py | 23 +- api/core/workflow/entities/variable_pool.py | 8 +- api/core/workflow/enums.py | 4 + api/core/workflow/graph/graph.py | 2 +- .../nodes/datasource/datasource_node.py | 69 +-- .../knowledge_index/knowledge_index_node.py | 14 +- api/models/dataset.py | 6 +- api/models/provider_ids.py | 5 + api/services/dataset_service.py | 22 +- api/services/datasource_provider_service.py | 6 +- .../rag_pipeline_entities.py | 20 +- api/services/rag_pipeline/rag_pipeline.py | 52 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 4 +- .../rag_pipeline_transform_service.py | 5 +- api/tasks/batch_clean_document_task.py | 5 +- .../rag_pipeline/rag_pipeline_run_task.py | 27 +- api/tests/unit_tests/core/schemas/__init__.py | 2 +- .../unit_tests/core/schemas/test_resolver.py | 493 +++++++----------- 35 files changed, 552 insertions(+), 617 deletions(-) diff --git a/api/commands.py b/api/commands.py index 23bf3d65ae..0874b2ffa0 100644 --- a/api/commands.py +++ b/api/commands.py @@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from constants.languages import languages from core.helper import encrypter -from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource +from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType @@ -35,7 +35,7 @@ from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.provider import Provider, ProviderModel -from models.provider_ids import ToolProviderID +from models.provider_ids import DatasourceProviderID, ToolProviderID from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a307ca0945..1a845cf326 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -11,10 +11,10 @@ from controllers.console.wraps import ( setup_required, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError -from 
core.plugin.entities.plugin import DatasourceProviderID from core.plugin.impl.oauth import OAuthHandler from libs.helper import StrLen from libs.login import login_required +from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 18cfac4fd8..cb95c2df43 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required -from models import db +from models.account import Account from models.dataset import Pipeline from models.workflow import WorkflowDraftVariable from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -131,7 +132,7 @@ def _api_prerequisite(f): @account_initialization_required @get_rag_pipeline def wrapper(*args, **kwargs): - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 0b48cb594b..f1a1f5f2b8 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource): Get draft rag pipeline's workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() # fetch draft workflow by app_model @@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource): Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() content_type = request.headers.get("Content-Type", "") @@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): Run draft workflow loop node """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource): Run draft workflow """ # The role 
of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource): Run published workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource): # # return result # - - class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @login_required @@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): Run rag pipeline datasource """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() @@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource): Run draft workflow node """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() @@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource): Stop workflow task """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource): Get published pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not pipeline.is_published: return None @@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource): Publish workflow """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() rag_pipeline_service = RagPipelineService() @@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() # Get default block configs @@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource): Get default block config """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + 
if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() if not isinstance(current_user, Account): @@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource): """ Get published workflows """ - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() @@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource): Update workflow attributes """ # Check permission - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() @@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource): Get second step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource): Get first step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource): Get first step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource): Get second step parameters of rag pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, location="args") @@ -926,8 +912,11 @@ class DatasourceListApi(Resource): @account_initialization_required def get(self): user = current_user - + if not isinstance(user, Account): + raise Forbidden() tenant_id = user.current_tenant_id + if not tenant_id: + raise Forbidden() return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) @@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource): """ Set datasource variables """ - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): + if not isinstance(current_user, Account) or not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 32fd47fd36..26783d8cf8 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -5,6 +5,7 @@ from typing import Optional from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_user +from models.account import Account 
from models.dataset import Pipeline @@ -17,6 +18,9 @@ def get_rag_pipeline( if not kwargs.get("pipeline_id"): raise ValueError("missing pipeline_id in path parameters") + if not isinstance(current_user, Account): + raise ValueError("current_user is not an account") + pipeline_id = kwargs.get("pipeline_id") pipeline_id = str(pipeline_id) diff --git a/api/controllers/console/spec.py b/api/controllers/console/spec.py index 8e10f95dc2..ca54715fe0 100644 --- a/api/controllers/console/spec.py +++ b/api/controllers/console/spec.py @@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource): return [], 200 -api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") \ No newline at end of file +api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 43232229c8..aede0de5b6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource): # validate args DocumentService.document_create_args_validate(knowledge_config) + if not current_user: + raise ValueError("current_user is required") + try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f7c83f927f..1d0fe2f6a0 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner): tenant_id=tenant_id, dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, - return_resource=app_config.additional_features.show_retrieve_source, + return_resource=( + app_config.additional_features.show_retrieve_source if app_config.additional_features else False + ), invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback, user_id=user_id, diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 52ae20ee16..6e02b0ebd2 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): if invoke_from == InvokeFrom.DEBUGGER: # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True + app_config.additional_features.show_retrieve_source = True # type: ignore workflow_run_id = str(uuid.uuid4()) # init application generate entity diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 894d7906d5..09b13e901a 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner): config=app_config.dataset, query=query, invoke_from=application_generate_entity.invoke_from, - show_retrieve_source=app_config.additional_features.show_retrieve_source, + show_retrieve_source=( + app_config.additional_features.show_retrieve_source if app_config.additional_features else False + ), hit_callback=hit_callback, memory=memory, message_id=message.id, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 444cedf7f6..b3a94e6d9f 100644 --- 
a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -36,8 +36,8 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File -from core.tools.entities.tool_entities import ToolProviderType from core.plugin.impl.datasource import PluginDatasourceManager +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 9c97b8109f..bc2e5c1bce 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Mapping -from typing import Any, Optional, cast +import time +from typing import Optional, cast -from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import ( RagPipelineGenerateEntity, ) from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput -from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent -from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph import Graph +from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent +from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry @@ -22,7 +23,7 @@ from extensions.ext_database import db from models.dataset import Document, Pipeline from models.enums import UserFrom from models.model import EndUser -from models.workflow import Workflow, WorkflowType +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner): db.session.close() - workflow_callbacks: list[WorkflowCallback] = [] - if dify_config.DEBUG: - workflow_callbacks.append(WorkflowLoggingCallback()) - # if only single iteration run is requested if self.application_generate_entity.single_iteration_run: + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool.empty(), + start_at=time.time(), + ) # if only single iteration run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=self.application_generate_entity.single_iteration_run.inputs, + graph_runtime_state=graph_runtime_state, ) elif self.application_generate_entity.single_loop_run: + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool.empty(), + start_at=time.time(), + ) # if only single loop run is requested graph, variable_pool = 
self._get_graph_and_variable_pool_of_single_loop( workflow=workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=self.application_generate_entity.single_loop_run.inputs, + graph_runtime_state=graph_runtime_state, ) else: inputs = self.application_generate_entity.inputs @@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner): datasource_info=self.application_generate_entity.datasource_info, invoke_from=self.application_generate_entity.invoke_from.value, ) + rag_pipeline_variables = [] if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: @@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner): conversation_variables=[], rag_pipeline_variables=rag_pipeline_variables, ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( - graph_config=workflow.graph_dict, + graph_runtime_state=graph_runtime_state, start_node_id=self.application_generate_entity.start_node_id, + workflow=workflow, ) # RUN WORKFLOW @@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner): tenant_id=workflow.tenant_id, app_id=workflow.app_id, workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), graph=graph, graph_config=workflow.graph_dict, user_id=self.application_generate_entity.user_id, @@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner): ), invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, - variable_pool=variable_pool, - thread_pool_id=self.workflow_thread_pool_id, + graph_runtime_state=graph_runtime_state, ) - generator = workflow_entry.run(callbacks=workflow_callbacks) + generator = workflow_entry.run() for event in generator: self._update_document_status( @@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner): # return workflow return workflow - def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph: + def _init_rag_pipeline_graph( + self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None + ) -> Graph: """ Init pipeline graph """ + graph_config = workflow.graph_dict if "nodes" not in graph_config or "edges" not in graph_config: raise ValueError("nodes or edges not found in workflow graph") @@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner): graph_config["nodes"] = real_run_nodes graph_config["edges"] = real_edges # init graph - graph = Graph.init(graph_config=graph_config) + # Create required parameters for Graph.init + graph_init_params = GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + workflow_id=workflow.id, + graph_config=graph_config, + user_id="", + user_from=UserFrom.ACCOUNT.value, + invoke_from=InvokeFrom.SERVICE_API.value, + call_depth=0, + ) + + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id) if not graph: raise ValueError("graph not found in workflow") diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 8568d9eecd..84087f8104 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import ( OnlineDriveDownloadFileRequest, WebsiteCrawlMessage, ) -from 
core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDatasourceProviderEntity, ) from core.plugin.impl.base import BasePluginClient from core.schemas.resolver import resolve_dify_schema_refs +from models.provider_ids import DatasourceProviderID, GenericProviderID from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 874049ce8e..379191d7f0 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any, TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from configs import dify_config from core.rag.extractor.entity.extract_setting import ExtractSetting diff --git a/api/core/schemas/__init__.py b/api/core/schemas/__init__.py index 863677bd5c..0e3833bf96 100644 --- a/api/core/schemas/__init__.py +++ b/api/core/schemas/__init__.py @@ -2,4 +2,4 @@ from .resolver import resolve_dify_schema_refs -__all__ = ["resolve_dify_schema_refs"] \ No newline at end of file +__all__ = ["resolve_dify_schema_refs"] diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 64765cee9f..b4cb6d8ae1 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional class SchemaRegistry: """Schema registry manages JSON schemas with version support""" - + _default_instance: ClassVar[Optional["SchemaRegistry"]] = None _lock: ClassVar[threading.Lock] = threading.Lock() @@ -25,41 +25,41 @@ class SchemaRegistry: if cls._default_instance is None: current_dir = Path(__file__).parent schema_dir = current_dir / "builtin" / "schemas" - + registry = cls(str(schema_dir)) registry.load_all_versions() - + cls._default_instance = registry - + return cls._default_instance def load_all_versions(self) -> None: """Scans the schema directory and loads all versions""" if not self.base_dir.exists(): return - + for entry in self.base_dir.iterdir(): if not entry.is_dir(): continue - + version = entry.name if not version.startswith("v"): continue - + self._load_version_dir(version, entry) def _load_version_dir(self, version: str, version_dir: Path) -> None: """Loads all schemas in a version directory""" if not version_dir.exists(): return - + if version not in self.versions: self.versions[version] = {} - + for entry in version_dir.iterdir(): if entry.suffix != ".json": continue - + schema_name = entry.stem self._load_schema(version, schema_name, entry) @@ -68,10 +68,10 @@ class SchemaRegistry: try: with open(schema_path, encoding="utf-8") as f: schema = json.load(f) - + # Store the schema self.versions[version][schema_name] = schema - + # Extract and store metadata uri = f"https://dify.ai/schemas/{version}/{schema_name}.json" metadata = { @@ -81,26 +81,26 @@ class SchemaRegistry: "deprecated": schema.get("deprecated", False), } self.metadata[uri] = metadata - + except (OSError, json.JSONDecodeError) as e: print(f"Warning: failed to load schema {version}/{schema_name}: {e}") - def get_schema(self, uri: str) -> Optional[Any]: """Retrieves a schema by URI with version support""" version, schema_name = self._parse_uri(uri) if not version or not schema_name: return None - + version_schemas = self.versions.get(version) if 
not version_schemas: return None - + return version_schemas.get(schema_name) def _parse_uri(self, uri: str) -> tuple[str, str]: """Parses a schema URI to extract version and schema name""" from core.schemas.resolver import parse_dify_schema_uri + return parse_dify_schema_uri(uri) def list_versions(self) -> list[str]: @@ -112,19 +112,15 @@ class SchemaRegistry: version_schemas = self.versions.get(version) if not version_schemas: return [] - + return sorted(version_schemas.keys()) def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]: """Returns all schemas for a version in the API format""" version_schemas = self.versions.get(version, {}) - + result = [] for schema_name, schema in version_schemas.items(): - result.append({ - "name": schema_name, - "label": schema.get("title", schema_name), - "schema": schema - }) - - return result \ No newline at end of file + result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema}) + + return result diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 3339dd9a6a..1c5dabd79b 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$ class SchemaResolutionError(Exception): """Base exception for schema resolution errors""" + pass class CircularReferenceError(SchemaResolutionError): """Raised when a circular reference is detected""" + def __init__(self, ref_uri: str, ref_path: list[str]): self.ref_uri = ref_uri self.ref_path = ref_path @@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError): class MaxDepthExceededError(SchemaResolutionError): """Raised when maximum resolution depth is exceeded""" + def __init__(self, max_depth: int): self.max_depth = max_depth super().__init__(f"Maximum resolution depth ({max_depth}) exceeded") @@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError): class SchemaNotFoundError(SchemaResolutionError): """Raised when a referenced schema cannot be found""" + def __init__(self, ref_uri: str): self.ref_uri = ref_uri super().__init__(f"Schema not found: {ref_uri}") @@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError): @dataclass class QueueItem: """Represents an item in the BFS queue""" + current: Any parent: Optional[Any] key: Optional[Union[str, int]] @@ -56,39 +61,39 @@ class QueueItem: class SchemaResolver: """Resolver for Dify schema references with caching and optimizations""" - + _cache: dict[str, SchemaDict] = {} _cache_lock = threading.Lock() - + def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10): """ Initialize the schema resolver - + Args: registry: Schema registry to use (defaults to default registry) max_depth: Maximum depth for reference resolution """ self.registry = registry or SchemaRegistry.default_registry() self.max_depth = max_depth - + @classmethod def clear_cache(cls) -> None: """Clear the global schema cache""" with cls._cache_lock: cls._cache.clear() - + def resolve(self, schema: SchemaType) -> SchemaType: """ Resolve all $ref references in the schema - + Performance optimization: quickly checks for $ref presence before processing. 
- + Args: schema: Schema to resolve - + Returns: Resolved schema with all references expanded - + Raises: CircularReferenceError: If circular reference detected MaxDepthExceededError: If max depth exceeded @@ -96,44 +101,39 @@ class SchemaResolver: """ if not isinstance(schema, (dict, list)): return schema - + # Fast path: if no Dify refs found, return original schema unchanged # This avoids expensive deepcopy and BFS traversal for schemas without refs if not _has_dify_refs(schema): return schema - + # Slow path: schema contains refs, perform full resolution import copy + result = copy.deepcopy(schema) - + # Initialize BFS queue - queue = deque([QueueItem( - current=result, - parent=None, - key=None, - depth=0, - ref_path=set() - )]) - + queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())]) + while queue: item = queue.popleft() - + # Process the current item self._process_queue_item(queue, item) - + return result - + def _process_queue_item(self, queue: deque, item: QueueItem) -> None: """Process a single queue item""" if isinstance(item.current, dict): self._process_dict(queue, item) elif isinstance(item.current, list): self._process_list(queue, item) - + def _process_dict(self, queue: deque, item: QueueItem) -> None: """Process a dictionary item""" ref_uri = item.current.get("$ref") - + if ref_uri and _is_dify_schema_ref(ref_uri): # Handle $ref resolution self._resolve_ref(queue, item, ref_uri) @@ -144,14 +144,10 @@ class SchemaResolver: next_depth = item.depth + 1 if next_depth >= self.max_depth: raise MaxDepthExceededError(self.max_depth) - queue.append(QueueItem( - current=value, - parent=item.current, - key=key, - depth=next_depth, - ref_path=item.ref_path - )) - + queue.append( + QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path) + ) + def _process_list(self, queue: deque, item: QueueItem) -> None: """Process a list item""" for idx, value in enumerate(item.current): @@ -159,14 +155,10 @@ class SchemaResolver: next_depth = item.depth + 1 if next_depth >= self.max_depth: raise MaxDepthExceededError(self.max_depth) - queue.append(QueueItem( - current=value, - parent=item.current, - key=idx, - depth=next_depth, - ref_path=item.ref_path - )) - + queue.append( + QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path) + ) + def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None: """Resolve a $ref reference""" # Check for circular reference @@ -175,82 +167,78 @@ class SchemaResolver: item.current["$circular_ref"] = True logger.warning("Circular reference detected: %s", ref_uri) return - + # Get resolved schema (from cache or registry) resolved_schema = self._get_resolved_schema(ref_uri) if not resolved_schema: logger.warning("Schema not found: %s", ref_uri) return - + # Update ref path new_ref_path = item.ref_path | {ref_uri} - + # Replace the reference with resolved schema next_depth = item.depth + 1 if next_depth >= self.max_depth: raise MaxDepthExceededError(self.max_depth) - + if item.parent is None: # Root level replacement item.current.clear() item.current.update(resolved_schema) - queue.append(QueueItem( - current=item.current, - parent=None, - key=None, - depth=next_depth, - ref_path=new_ref_path - )) + queue.append( + QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path) + ) else: # Update parent container item.parent[item.key] = resolved_schema.copy() - queue.append(QueueItem( - 
current=item.parent[item.key], - parent=item.parent, - key=item.key, - depth=next_depth, - ref_path=new_ref_path - )) - + queue.append( + QueueItem( + current=item.parent[item.key], + parent=item.parent, + key=item.key, + depth=next_depth, + ref_path=new_ref_path, + ) + ) + def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]: """Get resolved schema from cache or registry""" # Check cache first with self._cache_lock: if ref_uri in self._cache: return self._cache[ref_uri].copy() - + # Fetch from registry schema = self.registry.get_schema(ref_uri) if not schema: return None - + # Clean and cache cleaned = _remove_metadata_fields(schema) with self._cache_lock: self._cache[ref_uri] = cleaned - + return cleaned.copy() def resolve_dify_schema_refs( - schema: SchemaType, - registry: Optional[SchemaRegistry] = None, - max_depth: int = 30 + schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30 ) -> SchemaType: """ Resolve $ref references in Dify schema to actual schema content - + This is a convenience function that creates a resolver and resolves the schema. Performance optimization: quickly checks for $ref presence before processing. - + Args: schema: Schema object that may contain $ref references registry: Optional schema registry, defaults to default registry max_depth: Maximum depth to prevent infinite loops (default: 30) - + Returns: Schema with all $ref references resolved to actual content - + Raises: CircularReferenceError: If circular reference detected MaxDepthExceededError: If maximum depth exceeded @@ -260,7 +248,7 @@ def resolve_dify_schema_refs( # This avoids expensive deepcopy and BFS traversal for schemas without refs if not _has_dify_refs(schema): return schema - + # Slow path: schema contains refs, perform full resolution resolver = SchemaResolver(registry, max_depth) return resolver.resolve(schema) @@ -269,36 +257,36 @@ def resolve_dify_schema_refs( def _remove_metadata_fields(schema: dict) -> dict: """ Remove metadata fields from schema that shouldn't be included in resolved output - + Args: schema: Schema dictionary - + Returns: Cleaned schema without metadata fields """ # Create a copy and remove metadata fields cleaned = schema.copy() metadata_fields = ["$id", "$schema", "version"] - + for field in metadata_fields: cleaned.pop(field, None) - + return cleaned def _is_dify_schema_ref(ref_uri: Any) -> bool: """ Check if the reference URI is a Dify schema reference - + Args: ref_uri: URI to check - + Returns: True if it's a Dify schema reference """ if not isinstance(ref_uri, str): return False - + # Use pre-compiled pattern for better performance return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri)) @@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool: def _has_dify_refs_recursive(schema: SchemaType) -> bool: """ Recursively check if a schema contains any Dify $ref references - + This is the fallback method when string-based detection is not possible. 
- + Args: schema: Schema to check for references - + Returns: True if any Dify $ref is found, False otherwise """ @@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool: ref_uri = schema.get("$ref") if ref_uri and _is_dify_schema_ref(ref_uri): return True - + # Check nested values for value in schema.values(): if _has_dify_refs_recursive(value): return True - + elif isinstance(schema, list): # Check each item in the list for item in schema: if _has_dify_refs_recursive(item): return True - + # Primitive types don't contain refs return False @@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool: def _has_dify_refs_hybrid(schema: SchemaType) -> bool: """ Hybrid detection: fast string scan followed by precise recursive check - + Performance optimization using two-phase detection: 1. Fast string scan to quickly eliminate schemas without $ref 2. Precise recursive validation only for potential candidates - + Args: schema: Schema to check for references - + Returns: True if any Dify $ref is found, False otherwise """ # Phase 1: Fast string-based pre-filtering try: import json - schema_str = json.dumps(schema, separators=(',', ':')) - + + schema_str = json.dumps(schema, separators=(",", ":")) + # Quick elimination: no $ref at all if '"$ref"' not in schema_str: return False - + # Quick elimination: no Dify schema URLs - if 'https://dify.ai/schemas/' not in schema_str: + if "https://dify.ai/schemas/" not in schema_str: return False - + except (TypeError, ValueError, OverflowError): # JSON serialization failed (e.g., circular references, non-serializable objects) # Fall back to recursive detection logger.debug("JSON serialization failed for schema, using recursive detection") return _has_dify_refs_recursive(schema) - + # Phase 2: Precise recursive validation # Only executed for schemas that passed string pre-filtering return _has_dify_refs_recursive(schema) @@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool: def _has_dify_refs(schema: SchemaType) -> bool: """ Check if a schema contains any Dify $ref references - + Uses hybrid detection for optimal performance: - - Fast string scan for quick elimination + - Fast string scan for quick elimination - Precise recursive check for validation - + Args: schema: Schema to check for references - + Returns: True if any Dify $ref is found, False otherwise """ @@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool: def parse_dify_schema_uri(uri: str) -> tuple[str, str]: """ Parse a Dify schema URI to extract version and schema name - + Args: uri: Schema URI to parse - + Returns: Tuple of (version, schema_name) or ("", "") if invalid """ match = _DIFY_SCHEMA_PATTERN.match(uri) if not match: return "", "" - - return match.group(1), match.group(2) \ No newline at end of file + + return match.group(1), match.group(2) diff --git a/api/core/schemas/schema_manager.py b/api/core/schemas/schema_manager.py index 35a3b32fa5..3c9314db66 100644 --- a/api/core/schemas/schema_manager.py +++ b/api/core/schemas/schema_manager.py @@ -13,10 +13,10 @@ class SchemaManager: def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]: """ Get all JSON Schema definitions for a specific version - + Args: version: Schema version, defaults to v1 - + Returns: Array containing schema definitions, each element contains name and schema fields """ @@ -25,31 +25,28 @@ class SchemaManager: def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]: 
""" Get a specific schema by name - + Args: schema_name: Schema name version: Schema version, defaults to v1 - + Returns: Dictionary containing name and schema, returns None if not found """ uri = f"https://dify.ai/schemas/{version}/{schema_name}.json" schema = self.registry.get_schema(uri) - + if schema: - return { - "name": schema_name, - "schema": schema - } + return {"name": schema_name, "schema": schema} return None def list_available_schemas(self, version: str = "v1") -> list[str]: """ List all available schema names for a specific version - + Args: version: Schema version, defaults to v1 - + Returns: List of schema names """ @@ -58,8 +55,8 @@ class SchemaManager: def list_available_versions(self) -> list[str]: """ List all available schema versions - + Returns: List of versions """ - return self.registry.list_versions() \ No newline at end of file + return self.registry.list_versions() diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index f19128b445..bd03eb15ca 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -68,10 +68,10 @@ class VariablePool(BaseModel): # Add rag pipeline variables to the variable pool if self.rag_pipeline_variables: rag_pipeline_variables_map = defaultdict(dict) - for var in self.rag_pipeline_variables: - node_id = var.variable.belong_to_node_id - key = var.variable.variable - value = var.value + for rag_var in self.rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + value = rag_var.value rag_pipeline_variables_map[node_id][key] = value for key, value in rag_pipeline_variables_map.items(): self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 5e0441d340..f04f6ccc55 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -37,12 +37,14 @@ class NodeType(StrEnum): ANSWER = "answer" LLM = "llm" KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + KNOWLEDGE_INDEX = "knowledge-index" IF_ELSE = "if-else" CODE = "code" TEMPLATE_TRANSFORM = "template-transform" QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" TOOL = "tool" + DATASOURCE = "datasource" VARIABLE_AGGREGATOR = "variable-aggregator" LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. 
LOOP = "loop" @@ -83,6 +85,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" + RAG_PIPELINE = "rag-pipeline" class WorkflowExecutionStatus(StrEnum): @@ -116,6 +119,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output + DATASOURCE_INFO = "datasource_info" class WorkflowNodeExecutionStatus(StrEnum): diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 66fe100ee0..5bb02c8a7f 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -109,7 +109,7 @@ class Graph: start_node_id = None for nid in root_candidates: node_data = node_configs_map[nid].get("data", {}) - if node_data.get("type") == NodeType.START.value: + if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]: start_node_id = nid break diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 24a917d305..5fb199558d 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -19,16 +19,14 @@ from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.base import BaseNode +from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey +from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig -from core.workflow.nodes.enums import ErrorStrategy, NodeType -from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.tool.exc import ToolFileError -from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory from models.model import UploadFile @@ -39,7 +37,7 @@ from .entities import DatasourceNodeData from .exc import DatasourceNodeError, DatasourceParameterError -class DatasourceNode(BaseNode): +class DatasourceNode(Node): """ Datasource Node """ @@ -97,8 +95,8 @@ class DatasourceNode(BaseNode): datasource_type=DatasourceProviderType.value_of(datasource_type), ) except DatasourceNodeError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -172,8 +170,8 @@ class DatasourceNode(BaseNode): datasource_type=datasource_type, ) case DatasourceProviderType.WEBSITE_CRAWL: - yield RunCompletedEvent( - run_result=NodeRunResult( + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, 
inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -204,10 +202,10 @@ class DatasourceNode(BaseNode): size=upload_file.size, storage_key=upload_file.key, ) - variable_pool.add([self.node_id, "file"], file_info) + variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) - yield RunCompletedEvent( - run_result=NodeRunResult( + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -220,8 +218,8 @@ class DatasourceNode(BaseNode): case _: raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") except PluginDaemonClientSideError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -230,8 +228,8 @@ class DatasourceNode(BaseNode): ) ) except DatasourceNodeError as e: - yield RunCompletedEvent( - run_result=NodeRunResult( + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, @@ -425,8 +423,10 @@ class DatasourceNode(BaseNode): elif message.type == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=message.message.text, + is_final=False, ) elif message.type == DatasourceMessage.MessageType.JSON: assert isinstance(message.message, DatasourceMessage.JsonMessage) @@ -442,7 +442,11 @@ class DatasourceNode(BaseNode): assert isinstance(message.message, DatasourceMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) elif message.type == DatasourceMessage.MessageType.VARIABLE: assert isinstance(message.message, DatasourceMessage.VariableMessage) variable_name = message.message.variable_name @@ -454,17 +458,24 @@ class DatasourceNode(BaseNode): variables[variable_name] = "" variables[variable_name] += variable_value - yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, ) else: variables[variable_name] = variable_value elif message.type == DatasourceMessage.MessageType.FILE: assert message.meta is not None files.append(message.meta["file"]) - - yield RunCompletedEvent( - run_result=NodeRunResult( + # mark the end of the stream + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"json": json, "files": files, **variables, "text": text}, metadata={ @@ -526,9 +537,9 @@ class DatasourceNode(BaseNode): 
tenant_id=self.tenant_id, ) if file: - variable_pool.add([self.node_id, "file"], file) - yield RunCompletedEvent( - run_result=NodeRunResult( + variable_pool.add([self._node_id, "file"], file) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 0acbc513fe..2880518b94 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -9,16 +9,15 @@ from sqlalchemy import func from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey +from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey +from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig -from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.base.node import Node from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from ..base import BaseNode from .entities import KnowledgeIndexNodeData from .exc import ( KnowledgeIndexNodeError, @@ -35,7 +34,7 @@ default_retrieval_model = { } -class KnowledgeIndexNode(BaseNode): +class KnowledgeIndexNode(Node): _node_data: KnowledgeIndexNodeData _node_type = NodeType.KNOWLEDGE_INDEX @@ -93,15 +92,12 @@ class KnowledgeIndexNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - process_data=None, outputs=outputs, ) results = self._invoke_knowledge_index( dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results) except KnowledgeIndexNodeError as e: logger.warning("Error when running knowledge index node") diff --git a/api/models/dataset.py b/api/models/dataset.py index 9c3150ca5c..ff9559d7d8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -172,7 +172,7 @@ class Dataset(Base): ) @property - def doc_form(self): + def doc_form(self) -> Optional[str]: if self.chunk_structure: return self.chunk_structure document = db.session.query(Document).filter(Document.dataset_id == self.id).first() @@ -424,7 +424,7 @@ class Document(Base): return status @property - def data_source_info_dict(self): + def data_source_info_dict(self) -> dict[str, Any]: if self.data_source_info: try: data_source_info_dict = json.loads(self.data_source_info) @@ -432,7 +432,7 @@ class Document(Base): data_source_info_dict = {} return data_source_info_dict - return None + return {} @property def data_source_detail_dict(self): diff --git a/api/models/provider_ids.py b/api/models/provider_ids.py index 0a5af8cc77..98dc67f2f3 100644 --- a/api/models/provider_ids.py 
+++ b/api/models/provider_ids.py @@ -52,3 +52,8 @@ class ToolProviderID(GenericProviderID): if self.organization == "langgenius": if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: self.plugin_name = f"{self.provider_name}_tool" + + +class DatasourceProviderID(GenericProviderID): + def __init__(self, value: str, is_hardcoded: bool = False) -> None: + super().__init__(value, is_hardcoded) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index a559d5dc86..46b2c61800 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -718,9 +718,9 @@ class DatasetService: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=knowledge_configuration.embedding_model_provider, + provider=knowledge_configuration.embedding_model_provider or "", model_type=ModelType.TEXT_EMBEDDING, - model=knowledge_configuration.embedding_model, + model=knowledge_configuration.embedding_model or "", ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider @@ -1159,7 +1159,7 @@ class DocumentService: return documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() file_ids = [ - document.data_source_info_dict["upload_file_id"] + document.data_source_info_dict.get("upload_file_id", "") for document in documents if document.data_source_type == "upload_file" ] @@ -1281,7 +1281,7 @@ class DocumentService: account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", - ): + ) -> tuple[list[Document], str]: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) # check document limit @@ -1386,7 +1386,7 @@ class DocumentService: "Invalid process rule mode: %s, can not find dataset process rule", process_rule.mode, ) - return + return [], "" db.session.add(dataset_process_rule) db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" @@ -2595,7 +2595,9 @@ class SegmentService: return segment_data_list @classmethod - def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): + def update_segment( + cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> DocumentSegment: indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: @@ -2764,6 +2766,8 @@ class SegmentService: segment.error = str(e) db.session.commit() new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + if not new_segment: + raise ValueError("new_segment is not found") return new_segment @classmethod @@ -2804,7 +2808,11 @@ class SegmentService: index_node_ids = [seg.index_node_id for seg in segments] total_words = sum(seg.word_count for seg in segments) - document.word_count -= total_words + if document.word_count is None: + document.word_count = 0 + else: + document.word_count = max(0, document.word_count - total_words) + db.session.add(document) delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 307ee7867d..c28175c767 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -11,7 +11,6 @@ from core.helper import encrypter from 
core.helper.name_generator import generate_incremental_name from core.helper.provider_cache import NoOpProviderCredentialCache from core.model_runtime.entities.provider_entities import FormType -from core.plugin.entities.plugin import DatasourceProviderID from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import CredentialType @@ -19,6 +18,7 @@ from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncry from extensions.ext_database import db from extensions.ext_redis import redis_client from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider +from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -809,9 +809,7 @@ class DatasourceProviderService: credentials = self.list_datasource_credentials( tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id ) - redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" - ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" datasource_credentials.append( { "provider": datasource.provider, diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 77d72544ae..e215a89c15 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,6 +1,6 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, field_validator class IconInfo(BaseModel): @@ -110,7 +110,21 @@ class KnowledgeConfiguration(BaseModel): chunk_structure: str indexing_technique: Literal["high_quality", "economy"] - embedding_model_provider: Optional[str] = "" - embedding_model: Optional[str] = "" + embedding_model_provider: str = "" + embedding_model: str = "" keyword_number: Optional[int] = 10 retrieval_model: RetrievalSetting + + @field_validator("embedding_model_provider", mode="before") + @classmethod + def validate_embedding_model_provider(cls, v): + if v is None: + return "" + return v + + @field_validator("embedding_model", mode="before") + @classmethod + def validate_embedding_model(cls, v): + if v is None: + return "" + return v diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index deb645273f..05d74f3692 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -28,26 +28,23 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.rag.entities.event import ( - BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent, ) from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, 
WorkflowNodeExecutionStatus, ) -from core.workflow.enums import SystemVariableKey +from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import ErrorStrategy, NodeType -from core.workflow.nodes.event.event import RunCompletedEvent -from core.workflow.nodes.event.types import NodeEvent +from core.workflow.graph_events.base import GraphNodeEventBase +from core.workflow.node_events.base import NodeRunResult +from core.workflow.node_events.node import StreamCompletedEvent +from core.workflow.nodes.base.node import Node from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from core.workflow.system_variable import SystemVariable @@ -105,12 +102,13 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + return built_in_result else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) - return result + customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + return customized_result @classmethod def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity): @@ -471,7 +469,7 @@ class RagPipelineService: datasource_type: str, is_published: bool, credential_id: Optional[str] = None, - ) -> Generator[BaseDatasourceEvent, None, None]: + ) -> Generator[Mapping[str, Any], None, None]: """ Run published workflow datasource """ @@ -563,9 +561,9 @@ class RagPipelineService: user_id=account.id, request=OnlineDriveBrowseFilesRequest( bucket=user_inputs.get("bucket"), - prefix=user_inputs.get("prefix"), + prefix=user_inputs.get("prefix", ""), max_keys=user_inputs.get("max_keys", 20), - start_after=user_inputs.get("start_after"), + next_page_parameters=user_inputs.get("next_page_parameters"), ), provider_type=datasource_runtime.datasource_provider_type(), ) @@ -600,7 +598,7 @@ class RagPipelineService: end_time = time.time() if message.result.status == "completed": crawl_event = DatasourceCompletedEvent( - data=message.result.web_info_list, + data=message.result.web_info_list or [], total=message.result.total, completed=message.result.completed, time_consuming=round(end_time - start_time, 2), @@ -681,9 +679,9 @@ class RagPipelineService: datasource_runtime.get_online_document_page_content( user_id=account.id, datasource_parameters=GetOnlineDocumentPageContentRequest( - workspace_id=user_inputs.get("workspace_id"), - page_id=user_inputs.get("page_id"), - type=user_inputs.get("type"), + workspace_id=user_inputs.get("workspace_id", ""), + page_id=user_inputs.get("page_id", ""), + type=user_inputs.get("type", ""), ), provider_type=datasource_type, ) @@ -740,7 +738,7 @@ class RagPipelineService: def _handle_node_run_result( self, - getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], 
+ getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]], start_at: float, tenant_id: str, node_id: str, @@ -758,17 +756,16 @@ class RagPipelineService: node_run_result: NodeRunResult | None = None for event in generator: - if isinstance(event, RunCompletedEvent): - node_run_result = event.run_result - + if isinstance(event, StreamCompletedEvent): + node_run_result = event.node_run_result # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {} break if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error: + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy: node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, @@ -808,7 +805,7 @@ class RagPipelineService: workflow_id=node_instance.workflow_id, index=1, node_id=node_id, - node_type=node_instance.type_, + node_type=node_instance.node_type, title=node_instance.title, elapsed_time=time.perf_counter() - start_at, finished_at=datetime.now(UTC).replace(tzinfo=None), @@ -1148,7 +1145,7 @@ class RagPipelineService: .first() ) return node_exec - + def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser): # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(pipeline=pipeline) @@ -1208,6 +1205,3 @@ class RagPipelineService: ) session.commit() return workflow_node_execution_db_model - - - diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 8d288307ce..8447d4f16f 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -23,8 +23,8 @@ from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import PluginDependency +from core.workflow.enums import NodeType from core.workflow.nodes.datasource.entities import DatasourceNodeData -from core.workflow.nodes.enums import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData @@ -281,7 +281,7 @@ class RagPipelineDslService: icon = icon_info.icon icon_background = icon_info.icon_background icon_url = icon_info.icon_url - else: + else: icon_type = data.get("rag_pipeline", {}).get("icon_type") icon = data.get("rag_pipeline", {}).get("icon") icon_background = data.get("rag_pipeline", {}).get("icon_background") diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 3ab63e90c1..43eeb49a35 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -1,6 +1,7 @@ import json from datetime import UTC, datetime from pathlib import Path +from typing import Optional from uuid import uuid4 import yaml @@ -87,7 +88,7 @@ class RagPipelineTransformService: "status": "success", } - def 
_get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str): + def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]): if doc_form == "text_model": match datasource_type: case "upload_file": @@ -148,7 +149,7 @@ class RagPipelineTransformService: return node def _deal_knowledge_index( - self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict + self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict ): knowledge_configuration_dict = node.get("data", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 08e2c4a556..7a72c27b0c 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,5 +1,6 @@ import logging import time +from typing import Optional import click from celery import shared_task @@ -15,7 +16,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]): """ Clean document when document deleted. :param document_ids: document ids @@ -29,6 +30,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form start_at = time.perf_counter() try: + if not doc_form: + raise ValueError("doc_form is required") dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 9db8d9ad4d..ff31be5e93 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -21,14 +21,16 @@ from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom @shared_task(queue="dataset") -def rag_pipeline_run_task(pipeline_id: str, - application_generate_entity: dict, - user_id: str, - tenant_id: str, - workflow_id: str, - streaming: bool, - workflow_execution_id: str | None = None, - workflow_thread_pool_id: str | None = None): +def rag_pipeline_run_task( + pipeline_id: str, + application_generate_entity: dict, + user_id: str, + tenant_id: str, + workflow_id: str, + streaming: bool, + workflow_execution_id: str | None = None, + workflow_thread_pool_id: str | None = None, +): """ Async Run rag pipeline :param pipeline_id: Pipeline ID @@ -94,18 +96,19 @@ def rag_pipeline_run_task(pipeline_id: str, with current_app.app_context(): # Set the user directly in g for preserve_flask_contexts g._login_user = account - + # Copy context for thread (after setting user) context = contextvars.copy_context() - + # Get Flask app object in the main thread where app context exists flask_app = current_app._get_current_object() # type: ignore - + # Create a wrapper function that passes user context def _run_with_user_context(): # Don't create a new app context here - let _generate handle it # Just ensure the user is available in contextvars from core.app.apps.pipeline.pipeline_generator import PipelineGenerator + pipeline_generator = PipelineGenerator() pipeline_generator._generate( flask_app=flask_app, @@ -120,7 +123,7 @@ def rag_pipeline_run_task(pipeline_id: str, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, 
) - + # Create and start worker thread worker_thread = threading.Thread(target=_run_with_user_context) worker_thread.start() diff --git a/api/tests/unit_tests/core/schemas/__init__.py b/api/tests/unit_tests/core/schemas/__init__.py index e0072207e8..03ced3c3c9 100644 --- a/api/tests/unit_tests/core/schemas/__init__.py +++ b/api/tests/unit_tests/core/schemas/__init__.py @@ -1 +1 @@ -# Core schemas unit tests \ No newline at end of file +# Core schemas unit tests diff --git a/api/tests/unit_tests/core/schemas/test_resolver.py b/api/tests/unit_tests/core/schemas/test_resolver.py index 643059e0e8..dba73bde60 100644 --- a/api/tests/unit_tests/core/schemas/test_resolver.py +++ b/api/tests/unit_tests/core/schemas/test_resolver.py @@ -33,18 +33,16 @@ class TestSchemaResolver: def test_simple_ref_resolution(self): """Test resolving a simple $ref to a complete schema""" - schema_with_ref = { - "$ref": "https://dify.ai/schemas/v1/qa_structure.json" - } - + schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} + resolved = resolve_dify_schema_refs(schema_with_ref) - + # Should be resolved to the actual qa_structure schema assert resolved["type"] == "object" assert resolved["title"] == "Q&A Structure Schema" assert "qa_chunks" in resolved["properties"] assert resolved["properties"]["qa_chunks"]["type"] == "array" - + # Metadata fields should be removed assert "$id" not in resolved assert "$schema" not in resolved @@ -55,29 +53,24 @@ class TestSchemaResolver: nested_schema = { "type": "object", "properties": { - "file_data": { - "$ref": "https://dify.ai/schemas/v1/file.json" - }, - "metadata": { - "type": "string", - "description": "Additional metadata" - } - } + "file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, + "metadata": {"type": "string", "description": "Additional metadata"}, + }, } - + resolved = resolve_dify_schema_refs(nested_schema) - + # Original structure should be preserved assert resolved["type"] == "object" assert "metadata" in resolved["properties"] assert resolved["properties"]["metadata"]["type"] == "string" - + # $ref should be resolved file_schema = resolved["properties"]["file_data"] assert file_schema["type"] == "object" assert file_schema["title"] == "File Schema" assert "name" in file_schema["properties"] - + # Metadata fields should be removed from resolved schema assert "$id" not in file_schema assert "$schema" not in file_schema @@ -87,18 +80,16 @@ class TestSchemaResolver: """Test resolving $refs in array items""" array_schema = { "type": "array", - "items": { - "$ref": "https://dify.ai/schemas/v1/general_structure.json" - }, - "description": "Array of general structures" + "items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}, + "description": "Array of general structures", } - + resolved = resolve_dify_schema_refs(array_schema) - + # Array structure should be preserved assert resolved["type"] == "array" assert resolved["description"] == "Array of general structures" - + # Items $ref should be resolved items_schema = resolved["items"] assert items_schema["type"] == "array" @@ -109,20 +100,16 @@ class TestSchemaResolver: external_ref_schema = { "type": "object", "properties": { - "external_data": { - "$ref": "https://example.com/external-schema.json" - }, - "dify_data": { - "$ref": "https://dify.ai/schemas/v1/file.json" - } - } + "external_data": {"$ref": "https://example.com/external-schema.json"}, + "dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, + }, } - + resolved = 
resolve_dify_schema_refs(external_ref_schema) - + # External $ref should remain unchanged assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json" - + # Dify $ref should be resolved assert resolved["properties"]["dify_data"]["type"] == "object" assert resolved["properties"]["dify_data"]["title"] == "File Schema" @@ -132,22 +119,14 @@ class TestSchemaResolver: simple_schema = { "type": "object", "properties": { - "name": { - "type": "string", - "description": "Name field" - }, - "items": { - "type": "array", - "items": { - "type": "number" - } - } + "name": {"type": "string", "description": "Name field"}, + "items": {"type": "array", "items": {"type": "number"}}, }, - "required": ["name"] + "required": ["name"], } - + resolved = resolve_dify_schema_refs(simple_schema) - + # Should be identical to input assert resolved == simple_schema assert resolved["type"] == "object" @@ -159,21 +138,16 @@ class TestSchemaResolver: """Test that excessive recursion depth is prevented""" # Create a moderately nested structure deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} - + # Wrap it in fewer layers to make the test more reasonable for _ in range(2): - deep_schema = { - "type": "object", - "properties": { - "nested": deep_schema - } - } - + deep_schema = {"type": "object", "properties": {"nested": deep_schema}} + # Should handle normal cases fine with reasonable depth resolved = resolve_dify_schema_refs(deep_schema, max_depth=25) assert resolved is not None assert resolved["type"] == "object" - + # Should raise error with very low max_depth with pytest.raises(MaxDepthExceededError) as exc_info: resolve_dify_schema_refs(deep_schema, max_depth=5) @@ -185,12 +159,12 @@ class TestSchemaResolver: mock_registry = MagicMock() mock_registry.get_schema.side_effect = lambda uri: { "$ref": "https://dify.ai/schemas/v1/circular.json", - "type": "object" + "type": "object", } - + schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"} resolved = resolve_dify_schema_refs(schema, registry=mock_registry) - + # Should mark circular reference assert "$circular_ref" in resolved @@ -199,10 +173,10 @@ class TestSchemaResolver: # Mock registry that returns None for unknown schemas mock_registry = MagicMock() mock_registry.get_schema.return_value = None - + schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"} resolved = resolve_dify_schema_refs(schema, registry=mock_registry) - + # Should keep the original $ref when schema not found assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json" @@ -217,25 +191,25 @@ class TestSchemaResolver: def test_cache_functionality(self): """Test that caching works correctly""" schema = {"$ref": "https://dify.ai/schemas/v1/file.json"} - + # First resolution should fetch from registry resolved1 = resolve_dify_schema_refs(schema) - + # Mock the registry to return different data with patch.object(self.registry, "get_schema") as mock_get: mock_get.return_value = {"type": "different"} - + # Second resolution should use cache resolved2 = resolve_dify_schema_refs(schema) - + # Should be the same as first resolution (from cache) assert resolved1 == resolved2 # Mock should not have been called mock_get.assert_not_called() - + # Clear cache and try again SchemaResolver.clear_cache() - + # Now it should fetch again resolved3 = resolve_dify_schema_refs(schema) assert resolved3 == resolved1 @@ -244,14 +218,11 @@ class TestSchemaResolver: """Test that the resolver is thread-safe""" schema = { "type": "object", - 
"properties": { - f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} - for i in range(10) - } + "properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)}, } - + results = [] - + def resolve_in_thread(): try: result = resolve_dify_schema_refs(schema) @@ -260,12 +231,12 @@ class TestSchemaResolver: except Exception as e: results.append(e) return False - + # Run multiple threads concurrently with ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(resolve_in_thread) for _ in range(20)] success = all(f.result() for f in futures) - + assert success # All results should be the same first_result = results[0] @@ -276,10 +247,7 @@ class TestSchemaResolver: complex_schema = { "type": "object", "properties": { - "files": { - "type": "array", - "items": {"$ref": "https://dify.ai/schemas/v1/file.json"} - }, + "files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}, "nested": { "type": "object", "properties": { @@ -290,21 +258,21 @@ class TestSchemaResolver: "type": "object", "properties": { "general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"} - } - } - } - } - } - } + }, + }, + }, + }, + }, + }, } - + resolved = resolve_dify_schema_refs(complex_schema, max_depth=20) - + # Check structure is preserved assert resolved["type"] == "object" assert "files" in resolved["properties"] assert "nested" in resolved["properties"] - + # Check refs are resolved assert resolved["properties"]["files"]["items"]["type"] == "object" assert resolved["properties"]["files"]["items"]["title"] == "File Schema" @@ -314,14 +282,14 @@ class TestSchemaResolver: class TestUtilityFunctions: """Test utility functions""" - + def test_is_dify_schema_ref(self): """Test _is_dify_schema_ref function""" # Valid Dify refs assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json") assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json") assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json") - + # Invalid refs assert not _is_dify_schema_ref("https://example.com/schema.json") assert not _is_dify_schema_ref("https://dify.ai/other/path.json") @@ -330,61 +298,46 @@ class TestUtilityFunctions: assert not _is_dify_schema_ref(None) assert not _is_dify_schema_ref(123) assert not _is_dify_schema_ref(["list"]) - + def test_has_dify_refs(self): """Test _has_dify_refs function""" # Schemas with Dify refs assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"}) - assert _has_dify_refs({ - "type": "object", - "properties": { - "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} + assert _has_dify_refs( + {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}} + ) + assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}]) + assert _has_dify_refs( + { + "type": "array", + "items": { + "type": "object", + "properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}}, + }, } - }) - assert _has_dify_refs([ - {"type": "string"}, - {"$ref": "https://dify.ai/schemas/v1/file.json"} - ]) - assert _has_dify_refs({ - "type": "array", - "items": { - "type": "object", - "properties": { - "nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} - } - } - }) - + ) + # Schemas without Dify refs assert not _has_dify_refs({"type": "string"}) - assert not _has_dify_refs({ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "number"} - } - }) - assert not 
_has_dify_refs([ - {"type": "string"}, - {"type": "number"}, - {"type": "object", "properties": {"name": {"type": "string"}}} - ]) - + assert not _has_dify_refs( + {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}} + ) + assert not _has_dify_refs( + [{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}] + ) + # Schemas with non-Dify refs (should return False) assert not _has_dify_refs({"$ref": "https://example.com/schema.json"}) - assert not _has_dify_refs({ - "type": "object", - "properties": { - "external": {"$ref": "https://example.com/external.json"} - } - }) - + assert not _has_dify_refs( + {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}} + ) + # Primitive types assert not _has_dify_refs("string") assert not _has_dify_refs(123) assert not _has_dify_refs(True) assert not _has_dify_refs(None) - + def test_has_dify_refs_hybrid_vs_recursive(self): """Test that hybrid and recursive detection give same results""" test_schemas = [ @@ -392,29 +345,13 @@ class TestUtilityFunctions: {"type": "string"}, {"type": "object", "properties": {"name": {"type": "string"}}}, [{"type": "string"}, {"type": "number"}], - - # With Dify refs + # With Dify refs {"$ref": "https://dify.ai/schemas/v1/file.json"}, - { - "type": "object", - "properties": { - "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} - } - }, - [ - {"type": "string"}, - {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} - ], - + {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}, + [{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}], # With non-Dify refs {"$ref": "https://example.com/schema.json"}, - { - "type": "object", - "properties": { - "external": {"$ref": "https://example.com/external.json"} - } - }, - + {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}, # Complex nested { "type": "object", @@ -422,41 +359,40 @@ class TestUtilityFunctions: "level1": { "type": "object", "properties": { - "level2": { - "type": "array", - "items": {"$ref": "https://dify.ai/schemas/v1/file.json"} - } - } + "level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}} + }, } - } + }, }, - # Edge cases {"description": "This mentions $ref but is not a reference"}, {"$ref": "not-a-url"}, - # Primitive types - "string", 123, True, None, [] + "string", + 123, + True, + None, + [], ] - + for schema in test_schemas: hybrid_result = _has_dify_refs_hybrid(schema) recursive_result = _has_dify_refs_recursive(schema) - + assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}" - + def test_parse_dify_schema_uri(self): """Test parse_dify_schema_uri function""" # Valid URIs assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file") assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name") assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file") - + # Invalid URIs assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "") assert parse_dify_schema_uri("invalid") == ("", "") assert parse_dify_schema_uri("") == ("", "") - + def test_remove_metadata_fields(self): """Test _remove_metadata_fields function""" schema = { @@ -465,68 +401,68 @@ class TestUtilityFunctions: "version": "should be removed", "type": "object", "title": "should remain", - 
"properties": {} + "properties": {}, } - + cleaned = _remove_metadata_fields(schema) - + assert "$id" not in cleaned assert "$schema" not in cleaned assert "version" not in cleaned assert cleaned["type"] == "object" assert cleaned["title"] == "should remain" assert "properties" in cleaned - + # Original should be unchanged assert "$id" in schema class TestSchemaResolverClass: """Test SchemaResolver class specifically""" - + def test_resolver_initialization(self): """Test resolver initialization""" # Default initialization resolver = SchemaResolver() assert resolver.max_depth == 10 assert resolver.registry is not None - + # Custom initialization custom_registry = MagicMock() resolver = SchemaResolver(registry=custom_registry, max_depth=5) assert resolver.max_depth == 5 assert resolver.registry is custom_registry - + def test_cache_sharing(self): """Test that cache is shared between resolver instances""" SchemaResolver.clear_cache() - + schema = {"$ref": "https://dify.ai/schemas/v1/file.json"} - + # First resolver populates cache resolver1 = SchemaResolver() result1 = resolver1.resolve(schema) - + # Second resolver should use the same cache resolver2 = SchemaResolver() with patch.object(resolver2.registry, "get_schema") as mock_get: result2 = resolver2.resolve(schema) # Should not call registry since it's in cache mock_get.assert_not_called() - + assert result1 == result2 - + def test_resolver_with_list_schema(self): """Test resolver with list as root schema""" list_schema = [ {"$ref": "https://dify.ai/schemas/v1/file.json"}, {"type": "string"}, - {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} + {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}, ] - + resolver = SchemaResolver() resolved = resolver.resolve(list_schema) - + assert isinstance(resolved, list) assert len(resolved) == 3 assert resolved[0]["type"] == "object" @@ -534,20 +470,20 @@ class TestSchemaResolverClass: assert resolved[1] == {"type": "string"} assert resolved[2]["type"] == "object" assert resolved[2]["title"] == "Q&A Structure Schema" - + def test_cache_performance(self): """Test that caching improves performance""" SchemaResolver.clear_cache() - + # Create a schema with many references to the same schema schema = { "type": "object", "properties": { f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(50) # Reduced to avoid depth issues - } + }, } - + # First run (no cache) - run multiple times to warm up results1 = [] for _ in range(3): @@ -556,9 +492,9 @@ class TestSchemaResolverClass: result1 = resolve_dify_schema_refs(schema) time_no_cache = time.perf_counter() - start results1.append(time_no_cache) - + avg_time_no_cache = sum(results1) / len(results1) - + # Second run (with cache) - run multiple times results2 = [] for _ in range(3): @@ -566,14 +502,14 @@ class TestSchemaResolverClass: result2 = resolve_dify_schema_refs(schema) time_with_cache = time.perf_counter() - start results2.append(time_with_cache) - + avg_time_with_cache = sum(results2) / len(results2) - + # Cache should make it faster (more lenient check) assert result1 == result2 # Cache should provide some performance benefit assert avg_time_with_cache <= avg_time_no_cache - + def test_fast_path_performance_no_refs(self): """Test that schemas without $refs use fast path and avoid deep copying""" # Create a moderately complex schema without any $refs (typical plugin output_schema) @@ -585,16 +521,13 @@ class TestSchemaResolverClass: "properties": { "name": {"type": "string"}, "value": {"type": "number"}, - "items": { 
- "type": "array", - "items": {"type": "string"} - } - } + "items": {"type": "array", "items": {"type": "string"}}, + }, } for i in range(50) - } + }, } - + # Measure fast path (no refs) performance fast_times = [] for _ in range(10): @@ -602,21 +535,21 @@ class TestSchemaResolverClass: result_fast = resolve_dify_schema_refs(no_refs_schema) elapsed = time.perf_counter() - start fast_times.append(elapsed) - + avg_fast_time = sum(fast_times) / len(fast_times) - + # Most importantly: result should be identical to input (no copying) assert result_fast is no_refs_schema - + # Create schema with $refs for comparison (same structure size) with_refs_schema = { - "type": "object", + "type": "object", "properties": { f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(20) # Fewer to avoid depth issues but still comparable - } + }, } - + # Measure slow path (with refs) performance SchemaResolver.clear_cache() slow_times = [] @@ -626,63 +559,54 @@ class TestSchemaResolverClass: result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50) elapsed = time.perf_counter() - start slow_times.append(elapsed) - + avg_slow_time = sum(slow_times) / len(slow_times) - + # The key benefit: fast path should be reasonably fast (main goal is no deep copy) # and definitely avoid the expensive BFS resolution # Even if detection has some overhead, it should still be faster for typical cases print(f"Fast path (no refs): {avg_fast_time:.6f}s") print(f"Slow path (with refs): {avg_slow_time:.6f}s") - + # More lenient check: fast path should be at least somewhat competitive # The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster assert avg_fast_time < avg_slow_time * 2 # Should not be more than 2x slower - + def test_batch_processing_performance(self): """Test performance improvement for batch processing of schemas without refs""" # Simulate the plugin tool scenario: many schemas, most without refs schemas_without_refs = [ { "type": "object", - "properties": { - f"field_{j}": {"type": "string" if j % 2 else "number"} - for j in range(10) - } + "properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)}, } for i in range(100) ] - + # Test batch processing performance start = time.perf_counter() results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs] batch_time = time.perf_counter() - start - + # Verify all results are identical to inputs (fast path used) for original, result in zip(schemas_without_refs, results): assert result is original - + # Should be very fast - each schema should take < 0.001 seconds on average avg_time_per_schema = batch_time / len(schemas_without_refs) assert avg_time_per_schema < 0.001 - + def test_has_dify_refs_performance(self): """Test that _has_dify_refs is fast for large schemas without refs""" # Create a very large schema without refs - large_schema = { - "type": "object", - "properties": {} - } - + large_schema = {"type": "object", "properties": {}} + # Add many nested properties current = large_schema for i in range(100): - current["properties"][f"level_{i}"] = { - "type": "object", - "properties": {} - } + current["properties"][f"level_{i}"] = {"type": "object", "properties": {}} current = current["properties"][f"level_{i}"] - + # _has_dify_refs should be fast even for large schemas times = [] for _ in range(50): @@ -690,13 +614,13 @@ class TestSchemaResolverClass: has_refs = _has_dify_refs(large_schema) elapsed = time.perf_counter() - start times.append(elapsed) - + 
avg_time = sum(times) / len(times) - + # Should be False and fast assert not has_refs assert avg_time < 0.01 # Should complete in less than 10ms - + def test_hybrid_vs_recursive_performance(self): """Test performance comparison between hybrid and recursive detection""" # Create test schemas of different types and sizes @@ -704,16 +628,9 @@ class TestSchemaResolverClass: # Case 1: Small schema without refs (most common case) { "name": "small_no_refs", - "schema": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "value": {"type": "number"} - } - }, - "expected": False + "schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}}, + "expected": False, }, - # Case 2: Medium schema without refs { "name": "medium_no_refs", @@ -725,28 +642,16 @@ class TestSchemaResolverClass: "properties": { "name": {"type": "string"}, "value": {"type": "number"}, - "items": { - "type": "array", - "items": {"type": "string"} - } - } + "items": {"type": "array", "items": {"type": "string"}}, + }, } for i in range(20) - } + }, }, - "expected": False + "expected": False, }, - # Case 3: Large schema without refs - { - "name": "large_no_refs", - "schema": { - "type": "object", - "properties": {} - }, - "expected": False - }, - + {"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False}, # Case 4: Schema with Dify refs { "name": "with_dify_refs", @@ -754,45 +659,38 @@ class TestSchemaResolverClass: "type": "object", "properties": { "file": {"$ref": "https://dify.ai/schemas/v1/file.json"}, - "data": {"type": "string"} - } + "data": {"type": "string"}, + }, }, - "expected": True + "expected": True, }, - # Case 5: Schema with non-Dify refs { "name": "with_external_refs", "schema": { - "type": "object", - "properties": { - "external": {"$ref": "https://example.com/schema.json"}, - "data": {"type": "string"} - } + "type": "object", + "properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}}, }, - "expected": False - } + "expected": False, + }, ] - + # Add deep nesting to large schema current = test_cases[2]["schema"] for i in range(50): - current["properties"][f"level_{i}"] = { - "type": "object", - "properties": {} - } + current["properties"][f"level_{i}"] = {"type": "object", "properties": {}} current = current["properties"][f"level_{i}"] - + # Performance comparison for test_case in test_cases: schema = test_case["schema"] expected = test_case["expected"] name = test_case["name"] - + # Test correctness first assert _has_dify_refs_hybrid(schema) == expected assert _has_dify_refs_recursive(schema) == expected - + # Measure hybrid performance hybrid_times = [] for _ in range(10): @@ -800,7 +698,7 @@ class TestSchemaResolverClass: result_hybrid = _has_dify_refs_hybrid(schema) elapsed = time.perf_counter() - start hybrid_times.append(elapsed) - + # Measure recursive performance recursive_times = [] for _ in range(10): @@ -808,69 +706,62 @@ class TestSchemaResolverClass: result_recursive = _has_dify_refs_recursive(schema) elapsed = time.perf_counter() - start recursive_times.append(elapsed) - + avg_hybrid = sum(hybrid_times) / len(hybrid_times) avg_recursive = sum(recursive_times) / len(recursive_times) - + print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s") - + # Results should be identical assert result_hybrid == result_recursive == expected - + # For schemas without refs, hybrid should be competitive or better if not expected: # No refs case # Hybrid might be 
slightly slower due to JSON serialization overhead, # but should not be dramatically worse assert avg_hybrid < avg_recursive * 5 # At most 5x slower - + def test_string_matching_edge_cases(self): """Test edge cases for string-based detection""" # Case 1: False positive potential - $ref in description schema_false_positive = { "type": "object", "properties": { - "description": { - "type": "string", - "description": "This field explains how $ref works in JSON Schema" - } - } + "description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"} + }, } - + # Both methods should return False assert not _has_dify_refs_hybrid(schema_false_positive) assert not _has_dify_refs_recursive(schema_false_positive) - + # Case 2: Complex URL patterns complex_schema = { "type": "object", "properties": { "config": { - "type": "object", + "type": "object", "properties": { - "dify_url": { - "type": "string", - "default": "https://dify.ai/schemas/info" - }, - "actual_ref": { - "$ref": "https://dify.ai/schemas/v1/file.json" - } - } + "dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"}, + "actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"}, + }, } - } + }, } - + # Both methods should return True (due to actual_ref) assert _has_dify_refs_hybrid(complex_schema) assert _has_dify_refs_recursive(complex_schema) - + # Case 3: Non-JSON serializable objects (should fall back to recursive) import datetime + non_serializable = { "type": "object", "timestamp": datetime.datetime.now(), - "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} + "data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, } - + # Hybrid should fall back to recursive and still work assert _has_dify_refs_hybrid(non_serializable) - assert _has_dify_refs_recursive(non_serializable) \ No newline at end of file + assert _has_dify_refs_recursive(non_serializable)
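
Note on the resolver tests above: they exercise a two-tier detection strategy for Dify schema $refs, where a cheap string scan over the serialized schema short-circuits the common "no refs" case and a recursive walk handles everything else (including non-JSON-serializable input and false positives such as URLs inside descriptions). The sketch below is an illustrative approximation of those helpers, written against the behaviour the tests assert rather than the actual api/core/schemas/resolver.py implementation; the regex, the marker string, and the function bodies are assumptions.

# Illustrative sketch only, not the project code: approximates the helpers that
# the tests in api/tests/unit_tests/core/schemas/test_resolver.py exercise.
import json
import re

# Assumed shape of a resolvable Dify schema URI: https://dify.ai/schemas/v<N>/<name>.json
_DIFY_REF_PATTERN = re.compile(r"^https://dify\.ai/schemas/v\d+/[\w.-]+\.json$")


def _is_dify_schema_ref(value) -> bool:
    # Only strings matching the versioned Dify schema URI count as resolvable refs.
    return isinstance(value, str) and bool(_DIFY_REF_PATTERN.match(value))


def _has_dify_refs_recursive(schema) -> bool:
    # Exhaustive walk: inspect "$ref" keys in dicts and recurse into values and list items.
    if isinstance(schema, dict):
        if _is_dify_schema_ref(schema.get("$ref")):
            return True
        return any(_has_dify_refs_recursive(v) for v in schema.values())
    if isinstance(schema, list):
        return any(_has_dify_refs_recursive(item) for item in schema)
    return False


def _has_dify_refs_hybrid(schema) -> bool:
    # Fast path: if the serialized schema never mentions the Dify schema host,
    # there is nothing to resolve and the structure is never walked or copied.
    # Fall back to the recursive walk when the schema is not JSON-serializable
    # or when the marker string is present, to rule out false positives such as
    # "https://dify.ai/schemas/..." appearing only inside a description or default.
    try:
        if "https://dify.ai/schemas/" not in json.dumps(schema):
            return False
    except (TypeError, ValueError):
        pass
    return _has_dify_refs_recursive(schema)

Under these assumptions the two detectors agree on every case listed in test_has_dify_refs_hybrid_vs_recursive, which is the invariant the suite protects; the fast path only ever answers "no refs" on its own, so it can skip the deep copy without risking a missed resolution.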