mirror of https://github.com/langgenius/dify.git
merge new graph engine
This commit is contained in:
parent 6c8212d509
commit 90d72f5ddf
@@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType

@@ -35,7 +35,7 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import ToolProviderID
from models.provider_ids import DatasourceProviderID, ToolProviderID
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService

@@ -11,10 +11,10 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService

@@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import db
from models.account import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService

@@ -131,7 +132,7 @@ def _api_prerequisite(f):
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)

@@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
# fetch draft workflow by app_model

@@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
content_type = request.headers.get("Content-Type", "")

@@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required
@login_required

@@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

@@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

@@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

@@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not pipeline.is_published:
return None

@@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
rag_pipeline_service = RagPipelineService()

@@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
# Get default block configs

@@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):

@@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

@@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

@@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")

@@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")

@@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")

@@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")

@@ -926,8 +912,11 @@ class DatasourceListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
if not isinstance(user, Account):
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))

@@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource):
"""
Set datasource variables
"""
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

@@ -5,6 +5,7 @@ from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline

@@ -17,6 +18,9 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)

@@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource):
return [], 200
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

@@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource):
# validate args
DocumentService.document_create_args_validate(knowledge_config)
if not current_user:
raise ValueError("current_user is required")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,

@@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
tenant_id=tenant_id,
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
return_resource=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,

@@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity

@@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner):
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
hit_callback=hit_callback,
memory=memory,
message_id=message.id,

@@ -36,8 +36,8 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.entities.tool_entities import ToolProviderType
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution

@@ -1,8 +1,7 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast
import time
from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner

@@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import (
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry

@@ -22,7 +23,7 @@ from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType
from models.workflow import Workflow
logger = logging.getLogger(__name__)

@@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs

@@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:

@@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)
# RUN WORKFLOW

@@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,

@@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
graph_runtime_state=graph_runtime_state,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
generator = workflow_entry.run()
for event in generator:
self._update_document_status(

@@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
# return workflow
return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")

@@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:
raise ValueError("graph not found in workflow")

@@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import (
OnlineDriveDownloadFileRequest,
WebsiteCrawlMessage,
)
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.schemas.resolver import resolve_dify_schema_refs
from models.provider_ids import DatasourceProviderID, GenericProviderID
from services.tools.tools_transform_service import ToolTransformService

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from configs import dify_config
from core.rag.extractor.entity.extract_setting import ExtractSetting

@@ -2,4 +2,4 @@
from .resolver import resolve_dify_schema_refs
__all__ = ["resolve_dify_schema_refs"]
__all__ = ["resolve_dify_schema_refs"]

@@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional
class SchemaRegistry:
"""Schema registry manages JSON schemas with version support"""
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_lock: ClassVar[threading.Lock] = threading.Lock()

@@ -25,41 +25,41 @@ class SchemaRegistry:
if cls._default_instance is None:
current_dir = Path(__file__).parent
schema_dir = current_dir / "builtin" / "schemas"
registry = cls(str(schema_dir))
registry.load_all_versions()
cls._default_instance = registry
return cls._default_instance
def load_all_versions(self) -> None:
"""Scans the schema directory and loads all versions"""
if not self.base_dir.exists():
return
for entry in self.base_dir.iterdir():
if not entry.is_dir():
continue
version = entry.name
if not version.startswith("v"):
continue
self._load_version_dir(version, entry)
def _load_version_dir(self, version: str, version_dir: Path) -> None:
"""Loads all schemas in a version directory"""
if not version_dir.exists():
return
if version not in self.versions:
self.versions[version] = {}
for entry in version_dir.iterdir():
if entry.suffix != ".json":
continue
schema_name = entry.stem
self._load_schema(version, schema_name, entry)

@@ -68,10 +68,10 @@ class SchemaRegistry:
try:
with open(schema_path, encoding="utf-8") as f:
schema = json.load(f)
# Store the schema
self.versions[version][schema_name] = schema
# Extract and store metadata
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
metadata = {

@@ -81,26 +81,26 @@ class SchemaRegistry:
"deprecated": schema.get("deprecated", False),
}
self.metadata[uri] = metadata
except (OSError, json.JSONDecodeError) as e:
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
def get_schema(self, uri: str) -> Optional[Any]:
"""Retrieves a schema by URI with version support"""
version, schema_name = self._parse_uri(uri)
if not version or not schema_name:
return None
version_schemas = self.versions.get(version)
if not version_schemas:
return None
return version_schemas.get(schema_name)
def _parse_uri(self, uri: str) -> tuple[str, str]:
"""Parses a schema URI to extract version and schema name"""
from core.schemas.resolver import parse_dify_schema_uri
return parse_dify_schema_uri(uri)
def list_versions(self) -> list[str]:

@@ -112,19 +112,15 @@ class SchemaRegistry:
version_schemas = self.versions.get(version)
if not version_schemas:
return []
return sorted(version_schemas.keys())
def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
"""Returns all schemas for a version in the API format"""
version_schemas = self.versions.get(version, {})
result = []
for schema_name, schema in version_schemas.items():
result.append({
"name": schema_name,
"label": schema.get("title", schema_name),
"schema": schema
})
return result
result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema})
return result

@@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$
class SchemaResolutionError(Exception):
"""Base exception for schema resolution errors"""
pass
class CircularReferenceError(SchemaResolutionError):
"""Raised when a circular reference is detected"""
def __init__(self, ref_uri: str, ref_path: list[str]):
self.ref_uri = ref_uri
self.ref_path = ref_path

@@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError):
class MaxDepthExceededError(SchemaResolutionError):
"""Raised when maximum resolution depth is exceeded"""
def __init__(self, max_depth: int):
self.max_depth = max_depth
super().__init__(f"Maximum resolution depth ({max_depth}) exceeded")

@@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError):
class SchemaNotFoundError(SchemaResolutionError):
"""Raised when a referenced schema cannot be found"""
def __init__(self, ref_uri: str):
self.ref_uri = ref_uri
super().__init__(f"Schema not found: {ref_uri}")

@@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError):
@dataclass
class QueueItem:
"""Represents an item in the BFS queue"""
current: Any
parent: Optional[Any]
key: Optional[Union[str, int]]

@@ -56,39 +61,39 @@ class QueueItem:
class SchemaResolver:
"""Resolver for Dify schema references with caching and optimizations"""
_cache: dict[str, SchemaDict] = {}
_cache_lock = threading.Lock()
def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
"""
Initialize the schema resolver
Args:
registry: Schema registry to use (defaults to default registry)
max_depth: Maximum depth for reference resolution
"""
self.registry = registry or SchemaRegistry.default_registry()
self.max_depth = max_depth
@classmethod
def clear_cache(cls) -> None:
"""Clear the global schema cache"""
with cls._cache_lock:
cls._cache.clear()
def resolve(self, schema: SchemaType) -> SchemaType:
"""
Resolve all $ref references in the schema
Performance optimization: quickly checks for $ref presence before processing.
Args:
schema: Schema to resolve
Returns:
Resolved schema with all references expanded
Raises:
CircularReferenceError: If circular reference detected
MaxDepthExceededError: If max depth exceeded

@@ -96,44 +101,39 @@ class SchemaResolver:
"""
if not isinstance(schema, (dict, list)):
return schema
# Fast path: if no Dify refs found, return original schema unchanged
# This avoids expensive deepcopy and BFS traversal for schemas without refs
if not _has_dify_refs(schema):
return schema
# Slow path: schema contains refs, perform full resolution
import copy
result = copy.deepcopy(schema)
# Initialize BFS queue
queue = deque([QueueItem(
current=result,
parent=None,
key=None,
depth=0,
ref_path=set()
)])
queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())])
while queue:
item = queue.popleft()
# Process the current item
self._process_queue_item(queue, item)
return result
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
"""Process a single queue item"""
if isinstance(item.current, dict):
self._process_dict(queue, item)
elif isinstance(item.current, list):
self._process_list(queue, item)
def _process_dict(self, queue: deque, item: QueueItem) -> None:
"""Process a dictionary item"""
ref_uri = item.current.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
# Handle $ref resolution
self._resolve_ref(queue, item, ref_uri)

@@ -144,14 +144,10 @@ class SchemaResolver:
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
queue.append(QueueItem(
current=value,
parent=item.current,
key=key,
depth=next_depth,
ref_path=item.ref_path
))
queue.append(
QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path)
)
def _process_list(self, queue: deque, item: QueueItem) -> None:
"""Process a list item"""
for idx, value in enumerate(item.current):

@@ -159,14 +155,10 @@ class SchemaResolver:
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
queue.append(QueueItem(
current=value,
parent=item.current,
key=idx,
depth=next_depth,
ref_path=item.ref_path
))
queue.append(
QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path)
)
def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
"""Resolve a $ref reference"""
# Check for circular reference

@@ -175,82 +167,78 @@ class SchemaResolver:
item.current["$circular_ref"] = True
logger.warning("Circular reference detected: %s", ref_uri)
return
# Get resolved schema (from cache or registry)
resolved_schema = self._get_resolved_schema(ref_uri)
if not resolved_schema:
logger.warning("Schema not found: %s", ref_uri)
return
# Update ref path
new_ref_path = item.ref_path | {ref_uri}
# Replace the reference with resolved schema
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
if item.parent is None:
# Root level replacement
item.current.clear()
item.current.update(resolved_schema)
queue.append(QueueItem(
current=item.current,
parent=None,
key=None,
depth=next_depth,
ref_path=new_ref_path
))
queue.append(
QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path)
)
else:
# Update parent container
item.parent[item.key] = resolved_schema.copy()
queue.append(QueueItem(
current=item.parent[item.key],
parent=item.parent,
key=item.key,
depth=next_depth,
ref_path=new_ref_path
))
queue.append(
QueueItem(
current=item.parent[item.key],
parent=item.parent,
key=item.key,
depth=next_depth,
ref_path=new_ref_path,
)
)
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
"""Get resolved schema from cache or registry"""
# Check cache first
with self._cache_lock:
if ref_uri in self._cache:
return self._cache[ref_uri].copy()
# Fetch from registry
schema = self.registry.get_schema(ref_uri)
if not schema:
return None
# Clean and cache
cleaned = _remove_metadata_fields(schema)
with self._cache_lock:
self._cache[ref_uri] = cleaned
return cleaned.copy()
def resolve_dify_schema_refs(
schema: SchemaType,
registry: Optional[SchemaRegistry] = None,
max_depth: int = 30
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
) -> SchemaType:
"""
Resolve $ref references in Dify schema to actual schema content
This is a convenience function that creates a resolver and resolves the schema.
Performance optimization: quickly checks for $ref presence before processing.
Args:
schema: Schema object that may contain $ref references
registry: Optional schema registry, defaults to default registry
max_depth: Maximum depth to prevent infinite loops (default: 30)
Returns:
Schema with all $ref references resolved to actual content
Raises:
CircularReferenceError: If circular reference detected
MaxDepthExceededError: If maximum depth exceeded

@@ -260,7 +248,7 @@ def resolve_dify_schema_refs(
# This avoids expensive deepcopy and BFS traversal for schemas without refs
if not _has_dify_refs(schema):
return schema
# Slow path: schema contains refs, perform full resolution
resolver = SchemaResolver(registry, max_depth)
return resolver.resolve(schema)

@@ -269,36 +257,36 @@ def resolve_dify_schema_refs(
def _remove_metadata_fields(schema: dict) -> dict:
"""
Remove metadata fields from schema that shouldn't be included in resolved output
Args:
schema: Schema dictionary
Returns:
Cleaned schema without metadata fields
"""
# Create a copy and remove metadata fields
cleaned = schema.copy()
metadata_fields = ["$id", "$schema", "version"]
for field in metadata_fields:
cleaned.pop(field, None)
return cleaned
def _is_dify_schema_ref(ref_uri: Any) -> bool:
"""
Check if the reference URI is a Dify schema reference
Args:
ref_uri: URI to check
Returns:
True if it's a Dify schema reference
"""
if not isinstance(ref_uri, str):
return False
# Use pre-compiled pattern for better performance
return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri))

@@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool:
def _has_dify_refs_recursive(schema: SchemaType) -> bool:
"""
Recursively check if a schema contains any Dify $ref references
This is the fallback method when string-based detection is not possible.
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""

@@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
return True
elif isinstance(schema, list):
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Primitive types don't contain refs
return False

@@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
"""
Hybrid detection: fast string scan followed by precise recursive check
Performance optimization using two-phase detection:
1. Fast string scan to quickly eliminate schemas without $ref
2. Precise recursive validation only for potential candidates
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""
# Phase 1: Fast string-based pre-filtering
try:
import json
schema_str = json.dumps(schema, separators=(',', ':'))
schema_str = json.dumps(schema, separators=(",", ":"))
# Quick elimination: no $ref at all
if '"$ref"' not in schema_str:
return False
# Quick elimination: no Dify schema URLs
if 'https://dify.ai/schemas/' not in schema_str:
if "https://dify.ai/schemas/" not in schema_str:
return False
except (TypeError, ValueError, OverflowError):
# JSON serialization failed (e.g., circular references, non-serializable objects)
# Fall back to recursive detection
logger.debug("JSON serialization failed for schema, using recursive detection")
return _has_dify_refs_recursive(schema)
# Phase 2: Precise recursive validation
# Only executed for schemas that passed string pre-filtering
return _has_dify_refs_recursive(schema)

@@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
def _has_dify_refs(schema: SchemaType) -> bool:
"""
Check if a schema contains any Dify $ref references
Uses hybrid detection for optimal performance:
- Fast string scan for quick elimination
- Fast string scan for quick elimination
- Precise recursive check for validation
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""

@@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool:
def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
"""
Parse a Dify schema URI to extract version and schema name
Args:
uri: Schema URI to parse
Returns:
Tuple of (version, schema_name) or ("", "") if invalid
"""
match = _DIFY_SCHEMA_PATTERN.match(uri)
if not match:
return "", ""
return match.group(1), match.group(2)
return match.group(1), match.group(2)

@@ -13,10 +13,10 @@ class SchemaManager:
def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
"""
Get all JSON Schema definitions for a specific version
Args:
version: Schema version, defaults to v1
Returns:
Array containing schema definitions, each element contains name and schema fields
"""

@@ -25,31 +25,28 @@ class SchemaManager:
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
"""
Get a specific schema by name
Args:
schema_name: Schema name
version: Schema version, defaults to v1
Returns:
Dictionary containing name and schema, returns None if not found
"""
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
schema = self.registry.get_schema(uri)
if schema:
return {
"name": schema_name,
"schema": schema
}
return {"name": schema_name, "schema": schema}
return None
def list_available_schemas(self, version: str = "v1") -> list[str]:
"""
List all available schema names for a specific version
Args:
version: Schema version, defaults to v1
Returns:
List of schema names
"""

@@ -58,8 +55,8 @@ class SchemaManager:
def list_available_versions(self) -> list[str]:
"""
List all available schema versions
Returns:
List of versions
"""
return self.registry.list_versions()
return self.registry.list_versions()

@@ -68,10 +68,10 @@ class VariablePool(BaseModel):
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map = defaultdict(dict)
for var in self.rag_pipeline_variables:
node_id = var.variable.belong_to_node_id
key = var.variable.variable
value = var.value
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable
value = rag_var.value
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)

@@ -37,12 +37,14 @@ class NodeType(StrEnum):
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"

@@ -83,6 +85,7 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):

@@ -116,6 +119,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
class WorkflowNodeExecutionStatus(StrEnum):

@@ -109,7 +109,7 @@ class Graph:
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data", {})
if node_data.get("type") == NodeType.START.value:
if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]:
start_node_id = nid
break

@@ -19,16 +19,14 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

@@ -39,7 +37,7 @@ from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(BaseNode):
class DatasourceNode(Node):
"""
Datasource Node
"""

@@ -97,8 +95,8 @@ class DatasourceNode(BaseNode):
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -172,8 +170,8 @@ class DatasourceNode(BaseNode):
datasource_type=datasource_type,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -204,10 +202,10 @@ class DatasourceNode(BaseNode):
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], file_info)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -220,8 +218,8 @@ class DatasourceNode(BaseNode):
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -230,8 +228,8 @@ class DatasourceNode(BaseNode):
)
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -425,8 +423,10 @@ class DatasourceNode(BaseNode):
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)

@@ -442,7 +442,11 @@ class DatasourceNode(BaseNode):
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name

@@ -454,17 +458,24 @@ class DatasourceNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
yield RunCompletedEvent(
run_result=NodeRunResult(
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"json": json, "files": files, **variables, "text": text},
metadata={

@@ -526,9 +537,9 @@ class DatasourceNode(BaseNode):
tenant_id=self.tenant_id,
)
if file:
variable_pool.add([self.node_id, "file"], file)
yield RunCompletedEvent(
run_result=NodeRunResult(
variable_pool.add([self._node_id, "file"], file)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},

@@ -9,16 +9,15 @@ from sqlalchemy import func
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from ..base import BaseNode
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,

@@ -35,7 +34,7 @@ default_retrieval_model = {
}
class KnowledgeIndexNode(BaseNode):
class KnowledgeIndexNode(Node):
_node_data: KnowledgeIndexNodeData
_node_type = NodeType.KNOWLEDGE_INDEX

@@ -93,15 +92,12 @@ class KnowledgeIndexNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")

@ -172,7 +172,7 @@ class Dataset(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
def doc_form(self):
|
||||
def doc_form(self) -> Optional[str]:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
|
||||
|
|
@ -424,7 +424,7 @@ class Document(Base):
|
|||
return status
|
||||
|
||||
@property
|
||||
def data_source_info_dict(self):
|
||||
def data_source_info_dict(self) -> dict[str, Any]:
|
||||
if self.data_source_info:
|
||||
try:
|
||||
data_source_info_dict = json.loads(self.data_source_info)
|
||||
|
|
@ -432,7 +432,7 @@ class Document(Base):
|
|||
data_source_info_dict = {}
|
||||
|
||||
return data_source_info_dict
|
||||
return None
|
||||
return {}
|
||||
|
||||
@property
|
||||
def data_source_detail_dict(self):
|
||||
|
|
|
|||
|
|
@@ -52,3 +52,8 @@ class ToolProviderID(GenericProviderID):
|
|||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class DatasourceProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
|
|
|
|||
|
|
@@ -718,9 +718,9 @@ class DatasetService:
|
|||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=knowledge_configuration.embedding_model_provider,
|
||||
provider=knowledge_configuration.embedding_model_provider or "",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_configuration.embedding_model,
|
||||
model=knowledge_configuration.embedding_model or "",
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
|
|
@@ -1159,7 +1159,7 @@ class DocumentService:
|
|||
return
|
||||
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
|
||||
file_ids = [
|
||||
document.data_source_info_dict["upload_file_id"]
|
||||
document.data_source_info_dict.get("upload_file_id", "")
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file"
|
||||
]
|
||||
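The comprehension above can chain .get() without a None guard because Document.data_source_info_dict now always returns a dict (see the Document.data_source_info_dict hunk earlier in this diff, which returns {} instead of None). A minimal standalone sketch of that contract; the function name and inputs here are illustrative, not part of the change:

    import json
    from typing import Any

    def data_source_info_dict(raw: str | None) -> dict[str, Any]:
        # Contract after this change: always a dict, never None.
        if not raw:
            return {}
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return {}
        return data if isinstance(data, dict) else {}

    # Callers no longer need a None check before .get():
    file_id = data_source_info_dict(None).get("upload_file_id", "")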
|
|
@@ -1281,7 +1281,7 @@ class DocumentService:
|
|||
account: Account | Any,
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
) -> tuple[list[Document], str]:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
# check document limit
|
||||
|
|
@@ -1386,7 +1386,7 @@ class DocumentService:
|
|||
"Invalid process rule mode: %s, can not find dataset process rule",
|
||||
process_rule.mode,
|
||||
)
|
||||
return
|
||||
return [], ""
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.flush()
|
||||
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
|
||||
|
|
@@ -2595,7 +2595,9 @@ class SegmentService:
|
|||
return segment_data_list
|
||||
|
||||
@classmethod
|
||||
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
def update_segment(
|
||||
cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset
|
||||
) -> DocumentSegment:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
|
|
@@ -2764,6 +2766,8 @@ class SegmentService:
|
|||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||
if not new_segment:
|
||||
raise ValueError("new_segment is not found")
|
||||
return new_segment
|
||||
|
||||
@classmethod
|
||||
|
|
@@ -2804,7 +2808,11 @@ class SegmentService:
|
|||
index_node_ids = [seg.index_node_id for seg in segments]
|
||||
total_words = sum(seg.word_count for seg in segments)
|
||||
|
||||
document.word_count -= total_words
|
||||
if document.word_count is None:
|
||||
document.word_count = 0
|
||||
else:
|
||||
document.word_count = max(0, document.word_count - total_words)
|
||||
|
||||
db.session.add(document)
|
||||
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
|
||||
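The guard above replaces a bare document.word_count -= total_words, which could both go negative and fail when the counter was never initialized. The same logic as a tiny pure function, with illustrative names:

    def updated_word_count(current: int | None, removed: int) -> int:
        # None means the counter was never initialized; treat it as zero,
        # and never let the stored count go negative.
        if current is None:
            return 0
        return max(0, current - removed)

    assert updated_word_count(None, 25) == 0
    assert updated_word_count(10, 25) == 0
    assert updated_word_count(100, 25) == 75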
|
|
|
|||
|
|
@@ -11,7 +11,6 @@ from core.helper import encrypter
|
|||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.model_runtime.entities.provider_entities import FormType
|
||||
from core.plugin.entities.plugin import DatasourceProviderID
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
|
|
@@ -19,6 +18,7 @@ from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncry
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@@ -809,9 +809,7 @@ class DatasourceProviderService:
|
|||
credentials = self.list_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
datasource_credentials.append(
|
||||
{
|
||||
"provider": datasource.provider,
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
|
|
@@ -110,7 +110,21 @@ class KnowledgeConfiguration(BaseModel):
|
|||
|
||||
chunk_structure: str
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_model_provider: Optional[str] = ""
|
||||
embedding_model: Optional[str] = ""
|
||||
embedding_model_provider: str = ""
|
||||
embedding_model: str = ""
|
||||
keyword_number: Optional[int] = 10
|
||||
retrieval_model: RetrievalSetting
|
||||
|
||||
@field_validator("embedding_model_provider", mode="before")
|
||||
@classmethod
|
||||
def validate_embedding_model_provider(cls, v):
|
||||
if v is None:
|
||||
return ""
|
||||
return v
|
||||
|
||||
@field_validator("embedding_model", mode="before")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v):
|
||||
if v is None:
|
||||
return ""
|
||||
return v
|
||||
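The two validators above apply the same None-to-empty-string coercion so that payloads carrying an explicit null still satisfy the now non-Optional str fields. A self-contained sketch of the pattern; the model name below is illustrative:

    from pydantic import BaseModel, field_validator

    class ExampleKnowledgeConfig(BaseModel):
        # Plain str with a default, as embedding_model_provider is declared above.
        embedding_model_provider: str = ""

        @field_validator("embedding_model_provider", mode="before")
        @classmethod
        def _none_to_empty(cls, v):
            # Runs before type validation, so an explicit null is accepted as "unset".
            return "" if v is None else v

    assert ExampleKnowledgeConfig(embedding_model_provider=None).embedding_model_provider == ""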
|
|
|
|||
|
|
@@ -28,26 +28,23 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
|
|||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
from core.rag.entities.event import (
|
||||
BaseDatasourceEvent,
|
||||
DatasourceCompletedEvent,
|
||||
DatasourceErrorEvent,
|
||||
DatasourceProcessingEvent,
|
||||
)
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
from core.workflow.nodes.event.types import NodeEvent
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
from core.workflow.node_events.base import NodeRunResult
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@@ -105,12 +102,13 @@ class RagPipelineService:
|
|||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return built_in_result
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return result
|
||||
customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return customized_result
|
||||
|
||||
@classmethod
|
||||
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
|
||||
|
|
@@ -471,7 +469,7 @@ class RagPipelineService:
|
|||
datasource_type: str,
|
||||
is_published: bool,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> Generator[BaseDatasourceEvent, None, None]:
|
||||
) -> Generator[Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
"""
|
||||
|
|
@@ -563,9 +561,9 @@ class RagPipelineService:
|
|||
user_id=account.id,
|
||||
request=OnlineDriveBrowseFilesRequest(
|
||||
bucket=user_inputs.get("bucket"),
|
||||
prefix=user_inputs.get("prefix"),
|
||||
prefix=user_inputs.get("prefix", ""),
|
||||
max_keys=user_inputs.get("max_keys", 20),
|
||||
start_after=user_inputs.get("start_after"),
|
||||
next_page_parameters=user_inputs.get("next_page_parameters"),
|
||||
),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
|
|
@@ -600,7 +598,7 @@ class RagPipelineService:
|
|||
end_time = time.time()
|
||||
if message.result.status == "completed":
|
||||
crawl_event = DatasourceCompletedEvent(
|
||||
data=message.result.web_info_list,
|
||||
data=message.result.web_info_list or [],
|
||||
total=message.result.total,
|
||||
completed=message.result.completed,
|
||||
time_consuming=round(end_time - start_time, 2),
|
||||
|
|
@@ -681,9 +679,9 @@ class RagPipelineService:
|
|||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=account.id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=user_inputs.get("workspace_id"),
|
||||
page_id=user_inputs.get("page_id"),
|
||||
type=user_inputs.get("type"),
|
||||
workspace_id=user_inputs.get("workspace_id", ""),
|
||||
page_id=user_inputs.get("page_id", ""),
|
||||
type=user_inputs.get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
|
|
@@ -740,7 +738,7 @@ class RagPipelineService:
|
|||
|
||||
def _handle_node_run_result(
|
||||
self,
|
||||
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
|
|
@@ -758,17 +756,16 @@ class RagPipelineService:
|
|||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
node_run_result = event.run_result
|
||||
|
||||
if isinstance(event, StreamCompletedEvent):
|
||||
node_run_result = event.node_run_result
|
||||
# sign output files
|
||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
|
||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
|
||||
break
|
||||
|
||||
if not node_run_result:
|
||||
raise ValueError("Node run failed with no run result")
|
||||
# single step debug mode error handling return
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error:
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy:
|
||||
node_error_args: dict[str, Any] = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": node_run_result.error,
|
||||
|
|
@@ -808,7 +805,7 @@ class RagPipelineService:
|
|||
workflow_id=node_instance.workflow_id,
|
||||
index=1,
|
||||
node_id=node_id,
|
||||
node_type=node_instance.type_,
|
||||
node_type=node_instance.node_type,
|
||||
title=node_instance.title,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
finished_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
|
|
@@ -1148,7 +1145,7 @@ class RagPipelineService:
|
|||
.first()
|
||||
)
|
||||
return node_exec
|
||||
|
||||
|
||||
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser):
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
|
|
@@ -1208,6 +1205,3 @@ class RagPipelineService:
|
|||
)
|
||||
session.commit()
|
||||
return workflow_node_execution_db_model
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -23,8 +23,8 @@ from core.helper import ssrf_proxy
|
|||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
|
|
@@ -281,7 +281,7 @@ class RagPipelineDslService:
|
|||
icon = icon_info.icon
|
||||
icon_background = icon_info.icon_background
|
||||
icon_url = icon_info.icon_url
|
||||
else:
|
||||
else:
|
||||
icon_type = data.get("rag_pipeline", {}).get("icon_type")
|
||||
icon = data.get("rag_pipeline", {}).get("icon")
|
||||
icon_background = data.get("rag_pipeline", {}).get("icon_background")
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
|
|||
import json
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
|
@@ -87,7 +88,7 @@ class RagPipelineTransformService:
|
|||
"status": "success",
|
||||
}
|
||||
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str):
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
|
||||
if doc_form == "text_model":
|
||||
match datasource_type:
|
||||
case "upload_file":
|
||||
|
|
@@ -148,7 +149,7 @@ class RagPipelineTransformService:
|
|||
return node
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
|
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
|
||||
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
|
||||
"""
|
||||
Clean document when document deleted.
|
||||
:param document_ids: document ids
|
||||
|
|
@@ -29,6 +30,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
if not doc_form:
|
||||
raise ValueError("doc_form is required")
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
|
|
|
|||
|
|
@@ -21,14 +21,16 @@ from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
|||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def rag_pipeline_run_task(pipeline_id: str,
|
||||
application_generate_entity: dict,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
streaming: bool,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None):
|
||||
def rag_pipeline_run_task(
|
||||
pipeline_id: str,
|
||||
application_generate_entity: dict,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_id: str,
|
||||
streaming: bool,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Async Run rag pipeline
|
||||
:param pipeline_id: Pipeline ID
|
||||
|
|
@@ -94,18 +96,19 @@ def rag_pipeline_run_task(pipeline_id: str,
|
|||
with current_app.app_context():
|
||||
# Set the user directly in g for preserve_flask_contexts
|
||||
g._login_user = account
|
||||
|
||||
|
||||
# Copy context for thread (after setting user)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
|
||||
# Get Flask app object in the main thread where app context exists
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
|
||||
# Create a wrapper function that passes user context
|
||||
def _run_with_user_context():
|
||||
# Don't create a new app context here - let _generate handle it
|
||||
# Just ensure the user is available in contextvars
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
|
||||
pipeline_generator = PipelineGenerator()
|
||||
pipeline_generator._generate(
|
||||
flask_app=flask_app,
|
||||
|
|
@@ -120,7 +123,7 @@ def rag_pipeline_run_task(pipeline_id: str,
|
|||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
|
||||
# Create and start worker thread
|
||||
worker_thread = threading.Thread(target=_run_with_user_context)
|
||||
worker_thread.start()
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
|
|||
# Core schemas unit tests
|
||||
# Core schemas unit tests
|
||||
|
|
|
|||
|
|
@@ -33,18 +33,16 @@ class TestSchemaResolver:
|
|||
|
||||
def test_simple_ref_resolution(self):
|
||||
"""Test resolving a simple $ref to a complete schema"""
|
||||
schema_with_ref = {
|
||||
"$ref": "https://dify.ai/schemas/v1/qa_structure.json"
|
||||
}
|
||||
|
||||
schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
|
||||
resolved = resolve_dify_schema_refs(schema_with_ref)
|
||||
|
||||
|
||||
# Should be resolved to the actual qa_structure schema
|
||||
assert resolved["type"] == "object"
|
||||
assert resolved["title"] == "Q&A Structure Schema"
|
||||
assert "qa_chunks" in resolved["properties"]
|
||||
assert resolved["properties"]["qa_chunks"]["type"] == "array"
|
||||
|
||||
|
||||
# Metadata fields should be removed
|
||||
assert "$id" not in resolved
|
||||
assert "$schema" not in resolved
|
||||
|
|
@@ -55,29 +53,24 @@ class TestSchemaResolver:
|
|||
nested_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_data": {
|
||||
"$ref": "https://dify.ai/schemas/v1/file.json"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "string",
|
||||
"description": "Additional metadata"
|
||||
}
|
||||
}
|
||||
"file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
"metadata": {"type": "string", "description": "Additional metadata"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
resolved = resolve_dify_schema_refs(nested_schema)
|
||||
|
||||
|
||||
# Original structure should be preserved
|
||||
assert resolved["type"] == "object"
|
||||
assert "metadata" in resolved["properties"]
|
||||
assert resolved["properties"]["metadata"]["type"] == "string"
|
||||
|
||||
|
||||
# $ref should be resolved
|
||||
file_schema = resolved["properties"]["file_data"]
|
||||
assert file_schema["type"] == "object"
|
||||
assert file_schema["title"] == "File Schema"
|
||||
assert "name" in file_schema["properties"]
|
||||
|
||||
|
||||
# Metadata fields should be removed from resolved schema
|
||||
assert "$id" not in file_schema
|
||||
assert "$schema" not in file_schema
|
||||
|
|
@@ -87,18 +80,16 @@ class TestSchemaResolver:
|
|||
"""Test resolving $refs in array items"""
|
||||
array_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "https://dify.ai/schemas/v1/general_structure.json"
|
||||
},
|
||||
"description": "Array of general structures"
|
||||
"items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"},
|
||||
"description": "Array of general structures",
|
||||
}
|
||||
|
||||
|
||||
resolved = resolve_dify_schema_refs(array_schema)
|
||||
|
||||
|
||||
# Array structure should be preserved
|
||||
assert resolved["type"] == "array"
|
||||
assert resolved["description"] == "Array of general structures"
|
||||
|
||||
|
||||
# Items $ref should be resolved
|
||||
items_schema = resolved["items"]
|
||||
assert items_schema["type"] == "array"
|
||||
|
|
@@ -109,20 +100,16 @@ class TestSchemaResolver:
|
|||
external_ref_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"external_data": {
|
||||
"$ref": "https://example.com/external-schema.json"
|
||||
},
|
||||
"dify_data": {
|
||||
"$ref": "https://dify.ai/schemas/v1/file.json"
|
||||
}
|
||||
}
|
||||
"external_data": {"$ref": "https://example.com/external-schema.json"},
|
||||
"dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
resolved = resolve_dify_schema_refs(external_ref_schema)
|
||||
|
||||
|
||||
# External $ref should remain unchanged
|
||||
assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json"
|
||||
|
||||
|
||||
# Dify $ref should be resolved
|
||||
assert resolved["properties"]["dify_data"]["type"] == "object"
|
||||
assert resolved["properties"]["dify_data"]["title"] == "File Schema"
|
||||
|
|
@@ -132,22 +119,14 @@ class TestSchemaResolver:
|
|||
simple_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name field"
|
||||
},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
}
|
||||
"name": {"type": "string", "description": "Name field"},
|
||||
"items": {"type": "array", "items": {"type": "number"}},
|
||||
},
|
||||
"required": ["name"]
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
|
||||
resolved = resolve_dify_schema_refs(simple_schema)
|
||||
|
||||
|
||||
# Should be identical to input
|
||||
assert resolved == simple_schema
|
||||
assert resolved["type"] == "object"
|
||||
|
|
@@ -159,21 +138,16 @@ class TestSchemaResolver:
|
|||
"""Test that excessive recursion depth is prevented"""
|
||||
# Create a moderately nested structure
|
||||
deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
|
||||
|
||||
# Wrap it in fewer layers to make the test more reasonable
|
||||
for _ in range(2):
|
||||
deep_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested": deep_schema
|
||||
}
|
||||
}
|
||||
|
||||
deep_schema = {"type": "object", "properties": {"nested": deep_schema}}
|
||||
|
||||
# Should handle normal cases fine with reasonable depth
|
||||
resolved = resolve_dify_schema_refs(deep_schema, max_depth=25)
|
||||
assert resolved is not None
|
||||
assert resolved["type"] == "object"
|
||||
|
||||
|
||||
# Should raise error with very low max_depth
|
||||
with pytest.raises(MaxDepthExceededError) as exc_info:
|
||||
resolve_dify_schema_refs(deep_schema, max_depth=5)
|
||||
|
|
@@ -185,12 +159,12 @@ class TestSchemaResolver:
|
|||
mock_registry = MagicMock()
|
||||
mock_registry.get_schema.side_effect = lambda uri: {
|
||||
"$ref": "https://dify.ai/schemas/v1/circular.json",
|
||||
"type": "object"
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"}
|
||||
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
|
||||
|
||||
|
||||
# Should mark circular reference
|
||||
assert "$circular_ref" in resolved
|
||||
|
||||
|
|
@@ -199,10 +173,10 @@ class TestSchemaResolver:
|
|||
# Mock registry that returns None for unknown schemas
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_schema.return_value = None
|
||||
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"}
|
||||
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
|
||||
|
||||
|
||||
# Should keep the original $ref when schema not found
|
||||
assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json"
|
||||
|
||||
|
|
@@ -217,25 +191,25 @@ class TestSchemaResolver:
|
|||
def test_cache_functionality(self):
|
||||
"""Test that caching works correctly"""
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
|
||||
|
||||
# First resolution should fetch from registry
|
||||
resolved1 = resolve_dify_schema_refs(schema)
|
||||
|
||||
|
||||
# Mock the registry to return different data
|
||||
with patch.object(self.registry, "get_schema") as mock_get:
|
||||
mock_get.return_value = {"type": "different"}
|
||||
|
||||
|
||||
# Second resolution should use cache
|
||||
resolved2 = resolve_dify_schema_refs(schema)
|
||||
|
||||
|
||||
# Should be the same as first resolution (from cache)
|
||||
assert resolved1 == resolved2
|
||||
# Mock should not have been called
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
# Clear cache and try again
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
|
||||
# Now it should fetch again
|
||||
resolved3 = resolve_dify_schema_refs(schema)
|
||||
assert resolved3 == resolved1
|
||||
|
|
@@ -244,14 +218,11 @@ class TestSchemaResolver:
|
|||
"""Test that the resolver is thread-safe"""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
for i in range(10)
|
||||
}
|
||||
"properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)},
|
||||
}
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
def resolve_in_thread():
|
||||
try:
|
||||
result = resolve_dify_schema_refs(schema)
|
||||
|
|
@@ -260,12 +231,12 @@ class TestSchemaResolver:
|
|||
except Exception as e:
|
||||
results.append(e)
|
||||
return False
|
||||
|
||||
|
||||
# Run multiple threads concurrently
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(resolve_in_thread) for _ in range(20)]
|
||||
success = all(f.result() for f in futures)
|
||||
|
||||
|
||||
assert success
|
||||
# All results should be the same
|
||||
first_result = results[0]
|
||||
|
|
@@ -276,10 +247,7 @@ class TestSchemaResolver:
|
|||
complex_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
},
|
||||
"files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}},
|
||||
"nested": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@@ -290,21 +258,21 @@ class TestSchemaResolver:
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
resolved = resolve_dify_schema_refs(complex_schema, max_depth=20)
|
||||
|
||||
|
||||
# Check structure is preserved
|
||||
assert resolved["type"] == "object"
|
||||
assert "files" in resolved["properties"]
|
||||
assert "nested" in resolved["properties"]
|
||||
|
||||
|
||||
# Check refs are resolved
|
||||
assert resolved["properties"]["files"]["items"]["type"] == "object"
|
||||
assert resolved["properties"]["files"]["items"]["title"] == "File Schema"
|
||||
|
|
@@ -314,14 +282,14 @@ class TestSchemaResolver:
|
|||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions"""
|
||||
|
||||
|
||||
def test_is_dify_schema_ref(self):
|
||||
"""Test _is_dify_schema_ref function"""
|
||||
# Valid Dify refs
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json")
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json")
|
||||
assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json")
|
||||
|
||||
|
||||
# Invalid refs
|
||||
assert not _is_dify_schema_ref("https://example.com/schema.json")
|
||||
assert not _is_dify_schema_ref("https://dify.ai/other/path.json")
|
||||
|
|
@@ -330,61 +298,46 @@ class TestUtilityFunctions:
|
|||
assert not _is_dify_schema_ref(None)
|
||||
assert not _is_dify_schema_ref(123)
|
||||
assert not _is_dify_schema_ref(["list"])
|
||||
|
||||
|
||||
def test_has_dify_refs(self):
|
||||
"""Test _has_dify_refs function"""
|
||||
# Schemas with Dify refs
|
||||
assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"})
|
||||
assert _has_dify_refs({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
assert _has_dify_refs(
|
||||
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}
|
||||
)
|
||||
assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}])
|
||||
assert _has_dify_refs(
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}},
|
||||
},
|
||||
}
|
||||
})
|
||||
assert _has_dify_refs([
|
||||
{"type": "string"},
|
||||
{"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
])
|
||||
assert _has_dify_refs({
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
# Schemas without Dify refs
|
||||
assert not _has_dify_refs({"type": "string"})
|
||||
assert not _has_dify_refs({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"}
|
||||
}
|
||||
})
|
||||
assert not _has_dify_refs([
|
||||
{"type": "string"},
|
||||
{"type": "number"},
|
||||
{"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
])
|
||||
|
||||
assert not _has_dify_refs(
|
||||
{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}
|
||||
)
|
||||
assert not _has_dify_refs(
|
||||
[{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}]
|
||||
)
|
||||
|
||||
# Schemas with non-Dify refs (should return False)
|
||||
assert not _has_dify_refs({"$ref": "https://example.com/schema.json"})
|
||||
assert not _has_dify_refs({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"external": {"$ref": "https://example.com/external.json"}
|
||||
}
|
||||
})
|
||||
|
||||
assert not _has_dify_refs(
|
||||
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}
|
||||
)
|
||||
|
||||
# Primitive types
|
||||
assert not _has_dify_refs("string")
|
||||
assert not _has_dify_refs(123)
|
||||
assert not _has_dify_refs(True)
|
||||
assert not _has_dify_refs(None)
|
||||
|
||||
|
||||
def test_has_dify_refs_hybrid_vs_recursive(self):
|
||||
"""Test that hybrid and recursive detection give same results"""
|
||||
test_schemas = [
|
||||
|
|
@@ -392,29 +345,13 @@ class TestUtilityFunctions:
|
|||
{"type": "string"},
|
||||
{"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
[{"type": "string"}, {"type": "number"}],
|
||||
|
||||
# With Dify refs
|
||||
# With Dify refs
|
||||
{"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
}
|
||||
},
|
||||
[
|
||||
{"type": "string"},
|
||||
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
],
|
||||
|
||||
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}},
|
||||
[{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}],
|
||||
# With non-Dify refs
|
||||
{"$ref": "https://example.com/schema.json"},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"external": {"$ref": "https://example.com/external.json"}
|
||||
}
|
||||
},
|
||||
|
||||
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}},
|
||||
# Complex nested
|
||||
{
|
||||
"type": "object",
|
||||
|
|
@@ -422,41 +359,40 @@ class TestUtilityFunctions:
|
|||
"level1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
}
|
||||
}
|
||||
"level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
# Edge cases
|
||||
{"description": "This mentions $ref but is not a reference"},
|
||||
{"$ref": "not-a-url"},
|
||||
|
||||
# Primitive types
|
||||
"string", 123, True, None, []
|
||||
"string",
|
||||
123,
|
||||
True,
|
||||
None,
|
||||
[],
|
||||
]
|
||||
|
||||
|
||||
for schema in test_schemas:
|
||||
hybrid_result = _has_dify_refs_hybrid(schema)
|
||||
recursive_result = _has_dify_refs_recursive(schema)
|
||||
|
||||
|
||||
assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}"
|
||||
|
||||
|
||||
def test_parse_dify_schema_uri(self):
|
||||
"""Test parse_dify_schema_uri function"""
|
||||
# Valid URIs
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file")
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name")
|
||||
assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file")
|
||||
|
||||
|
||||
# Invalid URIs
|
||||
assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "")
|
||||
assert parse_dify_schema_uri("invalid") == ("", "")
|
||||
assert parse_dify_schema_uri("") == ("", "")
|
||||
|
||||
|
||||
def test_remove_metadata_fields(self):
|
||||
"""Test _remove_metadata_fields function"""
|
||||
schema = {
|
||||
|
|
@@ -465,68 +401,68 @@ class TestUtilityFunctions:
|
|||
"version": "should be removed",
|
||||
"type": "object",
|
||||
"title": "should remain",
|
||||
"properties": {}
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
|
||||
cleaned = _remove_metadata_fields(schema)
|
||||
|
||||
|
||||
assert "$id" not in cleaned
|
||||
assert "$schema" not in cleaned
|
||||
assert "version" not in cleaned
|
||||
assert cleaned["type"] == "object"
|
||||
assert cleaned["title"] == "should remain"
|
||||
assert "properties" in cleaned
|
||||
|
||||
|
||||
# Original should be unchanged
|
||||
assert "$id" in schema
|
||||
|
||||
|
||||
class TestSchemaResolverClass:
|
||||
"""Test SchemaResolver class specifically"""
|
||||
|
||||
|
||||
def test_resolver_initialization(self):
|
||||
"""Test resolver initialization"""
|
||||
# Default initialization
|
||||
resolver = SchemaResolver()
|
||||
assert resolver.max_depth == 10
|
||||
assert resolver.registry is not None
|
||||
|
||||
|
||||
# Custom initialization
|
||||
custom_registry = MagicMock()
|
||||
resolver = SchemaResolver(registry=custom_registry, max_depth=5)
|
||||
assert resolver.max_depth == 5
|
||||
assert resolver.registry is custom_registry
|
||||
|
||||
|
||||
def test_cache_sharing(self):
|
||||
"""Test that cache is shared between resolver instances"""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
|
||||
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
|
||||
|
||||
# First resolver populates cache
|
||||
resolver1 = SchemaResolver()
|
||||
result1 = resolver1.resolve(schema)
|
||||
|
||||
|
||||
# Second resolver should use the same cache
|
||||
resolver2 = SchemaResolver()
|
||||
with patch.object(resolver2.registry, "get_schema") as mock_get:
|
||||
result2 = resolver2.resolve(schema)
|
||||
# Should not call registry since it's in cache
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
|
||||
def test_resolver_with_list_schema(self):
|
||||
"""Test resolver with list as root schema"""
|
||||
list_schema = [
|
||||
{"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
{"type": "string"},
|
||||
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
|
||||
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
|
||||
]
|
||||
|
||||
|
||||
resolver = SchemaResolver()
|
||||
resolved = resolver.resolve(list_schema)
|
||||
|
||||
|
||||
assert isinstance(resolved, list)
|
||||
assert len(resolved) == 3
|
||||
assert resolved[0]["type"] == "object"
|
||||
|
|
@@ -534,20 +470,20 @@ class TestSchemaResolverClass:
|
|||
assert resolved[1] == {"type": "string"}
|
||||
assert resolved[2]["type"] == "object"
|
||||
assert resolved[2]["title"] == "Q&A Structure Schema"
|
||||
|
||||
|
||||
def test_cache_performance(self):
|
||||
"""Test that caching improves performance"""
|
||||
SchemaResolver.clear_cache()
|
||||
|
||||
|
||||
# Create a schema with many references to the same schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
for i in range(50) # Reduced to avoid depth issues
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# First run (no cache) - run multiple times to warm up
|
||||
results1 = []
|
||||
for _ in range(3):
|
||||
|
|
@@ -556,9 +492,9 @@ class TestSchemaResolverClass:
|
|||
result1 = resolve_dify_schema_refs(schema)
|
||||
time_no_cache = time.perf_counter() - start
|
||||
results1.append(time_no_cache)
|
||||
|
||||
|
||||
avg_time_no_cache = sum(results1) / len(results1)
|
||||
|
||||
|
||||
# Second run (with cache) - run multiple times
|
||||
results2 = []
|
||||
for _ in range(3):
|
||||
|
|
@@ -566,14 +502,14 @@ class TestSchemaResolverClass:
|
|||
result2 = resolve_dify_schema_refs(schema)
|
||||
time_with_cache = time.perf_counter() - start
|
||||
results2.append(time_with_cache)
|
||||
|
||||
|
||||
avg_time_with_cache = sum(results2) / len(results2)
|
||||
|
||||
|
||||
# Cache should make it faster (more lenient check)
|
||||
assert result1 == result2
|
||||
# Cache should provide some performance benefit
|
||||
assert avg_time_with_cache <= avg_time_no_cache
|
||||
|
||||
|
||||
def test_fast_path_performance_no_refs(self):
|
||||
"""Test that schemas without $refs use fast path and avoid deep copying"""
|
||||
# Create a moderately complex schema without any $refs (typical plugin output_schema)
|
||||
|
|
@@ -585,16 +521,13 @@ class TestSchemaResolverClass:
|
|||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
}
|
||||
"items": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
for i in range(50)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Measure fast path (no refs) performance
|
||||
fast_times = []
|
||||
for _ in range(10):
|
||||
|
|
@@ -602,21 +535,21 @@ class TestSchemaResolverClass:
|
|||
result_fast = resolve_dify_schema_refs(no_refs_schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
fast_times.append(elapsed)
|
||||
|
||||
|
||||
avg_fast_time = sum(fast_times) / len(fast_times)
|
||||
|
||||
|
||||
# Most importantly: result should be identical to input (no copying)
|
||||
assert result_fast is no_refs_schema
|
||||
|
||||
|
||||
# Create schema with $refs for comparison (same structure size)
|
||||
with_refs_schema = {
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
for i in range(20) # Fewer to avoid depth issues but still comparable
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Measure slow path (with refs) performance
|
||||
SchemaResolver.clear_cache()
|
||||
slow_times = []
|
||||
|
|
@@ -626,63 +559,54 @@ class TestSchemaResolverClass:
|
|||
result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50)
|
||||
elapsed = time.perf_counter() - start
|
||||
slow_times.append(elapsed)
|
||||
|
||||
|
||||
avg_slow_time = sum(slow_times) / len(slow_times)
|
||||
|
||||
|
||||
# The key benefit: fast path should be reasonably fast (main goal is no deep copy)
|
||||
# and definitely avoid the expensive BFS resolution
|
||||
# Even if detection has some overhead, it should still be faster for typical cases
|
||||
print(f"Fast path (no refs): {avg_fast_time:.6f}s")
|
||||
print(f"Slow path (with refs): {avg_slow_time:.6f}s")
|
||||
|
||||
|
||||
# More lenient check: fast path should be at least somewhat competitive
|
||||
# The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster
|
||||
assert avg_fast_time < avg_slow_time * 2 # Should not be more than 2x slower
|
||||
|
||||
|
||||
def test_batch_processing_performance(self):
|
||||
"""Test performance improvement for batch processing of schemas without refs"""
|
||||
# Simulate the plugin tool scenario: many schemas, most without refs
|
||||
schemas_without_refs = [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
f"field_{j}": {"type": "string" if j % 2 else "number"}
|
||||
for j in range(10)
|
||||
}
|
||||
"properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)},
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
|
||||
# Test batch processing performance
|
||||
start = time.perf_counter()
|
||||
results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs]
|
||||
batch_time = time.perf_counter() - start
|
||||
|
||||
|
||||
# Verify all results are identical to inputs (fast path used)
|
||||
for original, result in zip(schemas_without_refs, results):
|
||||
assert result is original
|
||||
|
||||
|
||||
# Should be very fast - each schema should take < 0.001 seconds on average
|
||||
avg_time_per_schema = batch_time / len(schemas_without_refs)
|
||||
assert avg_time_per_schema < 0.001
|
||||
|
||||
|
||||
def test_has_dify_refs_performance(self):
|
||||
"""Test that _has_dify_refs is fast for large schemas without refs"""
|
||||
# Create a very large schema without refs
|
||||
large_schema = {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
|
||||
large_schema = {"type": "object", "properties": {}}
|
||||
|
||||
# Add many nested properties
|
||||
current = large_schema
|
||||
for i in range(100):
|
||||
current["properties"][f"level_{i}"] = {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
|
||||
current = current["properties"][f"level_{i}"]
|
||||
|
||||
|
||||
# _has_dify_refs should be fast even for large schemas
|
||||
times = []
|
||||
for _ in range(50):
|
||||
|
|
@@ -690,13 +614,13 @@ class TestSchemaResolverClass:
|
|||
has_refs = _has_dify_refs(large_schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed)
|
||||
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
|
||||
|
||||
# Should be False and fast
|
||||
assert not has_refs
|
||||
assert avg_time < 0.01 # Should complete in less than 10ms
|
||||
|
||||
|
||||
def test_hybrid_vs_recursive_performance(self):
|
||||
"""Test performance comparison between hybrid and recursive detection"""
|
||||
# Create test schemas of different types and sizes
|
||||
|
|
@@ -704,16 +628,9 @@ class TestSchemaResolverClass:
|
|||
# Case 1: Small schema without refs (most common case)
|
||||
{
|
||||
"name": "small_no_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "number"}
|
||||
}
|
||||
},
|
||||
"expected": False
|
||||
"schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}},
|
||||
"expected": False,
|
||||
},
|
||||
|
||||
# Case 2: Medium schema without refs
|
||||
{
|
||||
"name": "medium_no_refs",
|
||||
|
|
@@ -725,28 +642,16 @@ class TestSchemaResolverClass:
|
|||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
}
|
||||
"items": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
for i in range(20)
|
||||
}
|
||||
},
|
||||
},
|
||||
"expected": False
|
||||
"expected": False,
|
||||
},
|
||||
|
||||
# Case 3: Large schema without refs
|
||||
{
|
||||
"name": "large_no_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
},
|
||||
"expected": False
|
||||
},
|
||||
|
||||
{"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False},
|
||||
# Case 4: Schema with Dify refs
|
||||
{
|
||||
"name": "with_dify_refs",
|
||||
|
|
@@ -754,45 +659,38 @@ class TestSchemaResolverClass:
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"file": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
"data": {"type": "string"}
|
||||
}
|
||||
"data": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"expected": True
|
||||
"expected": True,
|
||||
},
|
||||
|
||||
# Case 5: Schema with non-Dify refs
|
||||
{
|
||||
"name": "with_external_refs",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"external": {"$ref": "https://example.com/schema.json"},
|
||||
"data": {"type": "string"}
|
||||
}
|
||||
"type": "object",
|
||||
"properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}},
|
||||
},
|
||||
"expected": False
|
||||
}
|
||||
"expected": False,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Add deep nesting to large schema
|
||||
current = test_cases[2]["schema"]
|
||||
for i in range(50):
|
||||
current["properties"][f"level_{i}"] = {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
|
||||
current = current["properties"][f"level_{i}"]
|
||||
|
||||
|
||||
# Performance comparison
|
||||
for test_case in test_cases:
|
||||
schema = test_case["schema"]
|
||||
expected = test_case["expected"]
|
||||
name = test_case["name"]
|
||||
|
||||
|
||||
# Test correctness first
|
||||
assert _has_dify_refs_hybrid(schema) == expected
|
||||
assert _has_dify_refs_recursive(schema) == expected
|
||||
|
||||
|
||||
# Measure hybrid performance
|
||||
hybrid_times = []
|
||||
for _ in range(10):
|
||||
|
|
@@ -800,7 +698,7 @@ class TestSchemaResolverClass:
|
|||
result_hybrid = _has_dify_refs_hybrid(schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
hybrid_times.append(elapsed)
|
||||
|
||||
|
||||
# Measure recursive performance
|
||||
recursive_times = []
|
||||
for _ in range(10):
|
||||
|
|
@@ -808,69 +706,62 @@ class TestSchemaResolverClass:
|
|||
result_recursive = _has_dify_refs_recursive(schema)
|
||||
elapsed = time.perf_counter() - start
|
||||
recursive_times.append(elapsed)
|
||||
|
||||
|
||||
avg_hybrid = sum(hybrid_times) / len(hybrid_times)
|
||||
avg_recursive = sum(recursive_times) / len(recursive_times)
|
||||
|
||||
|
||||
print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s")
|
||||
|
||||
|
||||
# Results should be identical
|
||||
assert result_hybrid == result_recursive == expected
|
||||
|
||||
|
||||
# For schemas without refs, hybrid should be competitive or better
|
||||
if not expected: # No refs case
|
||||
# Hybrid might be slightly slower due to JSON serialization overhead,
|
||||
# but should not be dramatically worse
|
||||
assert avg_hybrid < avg_recursive * 5 # At most 5x slower
|
||||
|
||||
|
||||
def test_string_matching_edge_cases(self):
|
||||
"""Test edge cases for string-based detection"""
|
||||
# Case 1: False positive potential - $ref in description
|
||||
schema_false_positive = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "This field explains how $ref works in JSON Schema"
|
||||
}
|
||||
}
|
||||
"description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Both methods should return False
|
||||
assert not _has_dify_refs_hybrid(schema_false_positive)
|
||||
assert not _has_dify_refs_recursive(schema_false_positive)
|
||||
|
||||
|
||||
# Case 2: Complex URL patterns
|
||||
complex_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dify_url": {
|
||||
"type": "string",
|
||||
"default": "https://dify.ai/schemas/info"
|
||||
},
|
||||
"actual_ref": {
|
||||
"$ref": "https://dify.ai/schemas/v1/file.json"
|
||||
}
|
||||
}
|
||||
"dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"},
|
||||
"actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Both methods should return True (due to actual_ref)
|
||||
assert _has_dify_refs_hybrid(complex_schema)
|
||||
assert _has_dify_refs_recursive(complex_schema)
|
||||
|
||||
|
||||
# Case 3: Non-JSON serializable objects (should fall back to recursive)
|
||||
import datetime
|
||||
|
||||
non_serializable = {
|
||||
"type": "object",
|
||||
"timestamp": datetime.datetime.now(),
|
||||
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
|
||||
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
|
||||
}
|
||||
|
||||
|
||||
# Hybrid should fall back to recursive and still work
|
||||
assert _has_dify_refs_hybrid(non_serializable)
|
||||
assert _has_dify_refs_recursive(non_serializable)
|
||||
assert _has_dify_refs_recursive(non_serializable)
|
||||
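These tests pin down the observable behavior of _has_dify_refs_hybrid without showing its implementation. Purely as an illustration of a detector that would satisfy them (an assumption, not the code under test): pair a cheap substring scan over the serialized schema with an exact recursive walk as the fallback and the final authority.

    import json
    from typing import Any

    DIFY_SCHEMA_PREFIX = "https://dify.ai/schemas/"  # assumed prefix, matching the URIs in these tests

    def has_dify_refs_sketch(schema: Any) -> bool:
        # Fast path: if the serialized schema contains no "$ref" key at all, nothing can need resolving.
        try:
            if '"$ref"' not in json.dumps(schema):
                return False
        except (TypeError, ValueError):
            pass  # not JSON-serializable (e.g. datetime values): fall back to the walk below

        # Exact path: only a real "$ref" key pointing at a Dify schema counts,
        # so "$ref" mentioned inside a description string is not a match.
        def walk(node: Any) -> bool:
            if isinstance(node, dict):
                ref = node.get("$ref")
                if isinstance(ref, str) and ref.startswith(DIFY_SCHEMA_PREFIX):
                    return True
                return any(walk(value) for value in node.values())
            if isinstance(node, list):
                return any(walk(item) for item in node)
            return False

        return walk(schema)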
|
|
|
|||