Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-12 10:16:13 +08:00
commit e4c056a57a
381 changed files with 29436 additions and 8157 deletions

View File

@ -7,6 +7,7 @@
## Summary
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
## Screenshots

View File

@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or have a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt in.

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Literal
from typing import Any, Literal, TypedDict
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings):
)
class SQLAlchemyEngineOptionsDict(TypedDict):
pool_size: int
max_overflow: int
pool_recycle: int
pool_pre_ping: bool
connect_args: dict[str, str]
pool_use_lifo: bool
pool_reset_on_return: None
pool_timeout: int
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
@ -209,11 +220,11 @@ class DatabaseConfig(BaseSettings):
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
connect_args = {}
connect_args: dict[str, str] = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
@ -223,7 +234,7 @@ class DatabaseConfig(BaseSettings):
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
result: SQLAlchemyEngineOptionsDict = {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
@ -233,6 +244,7 @@ class DatabaseConfig(BaseSettings):
"pool_reset_on_return": None,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
}
return result
class CeleryConfig(DatabaseConfig):
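For reference, the typed-return pattern applied above, reduced to a standalone sketch (the names below are illustrative, not part of this change):

from typing import TypedDict


class EngineOptionsDict(TypedDict):
    pool_size: int
    pool_pre_ping: bool


def engine_options(pool_size: int) -> EngineOptionsDict:
    # assigning the literal to an annotated variable lets the type checker
    # verify the exact keys and value types before returning
    result: EngineOptionsDict = {"pool_size": pool_size, "pool_pre_ping": True}
    return result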

View File

@ -1,9 +1,11 @@
import json
from typing import cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
class ModelConfigRequest(BaseModel):
provider: str | None = Field(default=None, description="Model provider")
model: str | None = Field(default=None, description="Model name")
configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters")
opening_statement: str | None = Field(default=None, description="Opening statement")
suggested_questions: list[str] | None = Field(default=None, description="Suggested questions")
more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration")
speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration")
text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration")
retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration")
tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools")
dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations")
agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration")
register_schema_models(console_ns, ModelConfigRequest)
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@console_ns.doc("update_app_model_config")
@console_ns.doc(description="Update application model configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
"model": fields.String(description="Model name"),
"configs": fields.Raw(description="Model configuration parameters"),
"opening_statement": fields.String(description="Opening statement"),
"suggested_questions": fields.List(fields.String(), description="Suggested questions"),
"more_like_this": fields.Raw(description="More like this configuration"),
"speech_to_text": fields.Raw(description="Speech to text configuration"),
"text_to_speech": fields.Raw(description="Text to speech configuration"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"tools": fields.List(fields.Raw(), description="Available tools"),
"dataset_configs": fields.Raw(description="Dataset configurations"),
"agent_mode": fields.Raw(description="Agent mode configuration"),
},
)
)
@console_ns.expect(console_ns.models[ModelConfigRequest.__name__])
@console_ns.response(200, "Model configuration updated successfully")
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
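A minimal sketch of the schema-registration pattern used above, assuming `register_schema_models` registers each Pydantic model under its class name in `console_ns.models` (the `ExamplePayload` model is hypothetical):

from pydantic import BaseModel, Field

from controllers.common.schema import register_schema_models
from controllers.console import console_ns


class ExamplePayload(BaseModel):
    name: str = Field(..., description="Example field")


register_schema_models(console_ns, ExamplePayload)

# the registered schema can then be referenced by class name in @console_ns.expect(...)
example_schema = console_ns.models[ExamplePayload.__name__]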

View File

@ -2,18 +2,17 @@ import base64
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
@ -24,8 +23,7 @@ class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload)
@console_ns.route("/billing/subscription")
@ -58,12 +56,7 @@ class PartnerTenants(Resource):
@console_ns.doc("sync_partner_tenants_bindings")
@console_ns.doc(description="Sync partner tenants bindings")
@console_ns.doc(params={"partner_key": "Partner key"})
@console_ns.expect(
console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__])
@console_ns.response(200, "Tenants synced to partner successfully")
@console_ns.response(400, "Invalid partner information")
@setup_required

View File

@ -162,7 +162,9 @@ class DataSourceApi(Resource):
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource):
raise ValueError("Dataset is not notion type.")
documents = session.scalars(
select(Document).filter_by(
dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
select(Document).where(
Document.dataset_id == query.dataset_id,
Document.tenant_id == current_tenant_id,
Document.data_source_type == "notion_import",
Document.enabled.is_(True),
)
).all()
if documents:
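A condensed sketch of the legacy-to-2.0 query style change made in these hunks, assuming a mapped `Document` model and an open `session`:

from sqlalchemy import select


def enabled_notion_documents(session, dataset_id: str, tenant_id: str):
    # explicit column expressions replace the keyword-based filter_by() matching
    return session.scalars(
        select(Document).where(
            Document.dataset_id == dataset_id,
            Document.tenant_id == tenant_id,
            Document.data_source_type == "notion_import",
            Document.enabled.is_(True),
        )
    ).all()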

View File

@ -280,7 +280,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)

View File

@ -527,7 +527,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)
if query_params.status:
query = DocumentService.apply_display_status_filter(query, query_params.status)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import enum
from enum import StrEnum
from typing import Any
from typing import Any, TypedDict
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL
@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMetaDict(TypedDict):
time_cost: float
error: str | None
tool_config: dict[str, Any] | None
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
def to_dict(self) -> DatasourceInvokeMetaDict:
result: DatasourceInvokeMetaDict = {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
return result
class DatasourceLabel(BaseModel):

View File

@ -2,7 +2,7 @@ import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast
from graphon.file import file_manager
from graphon.model_runtime.entities.message_entities import (
@ -34,6 +34,13 @@ class ModelMode(StrEnum):
prompt_file_contents: dict[str, Any] = {}
class PromptTemplateConfigDict(TypedDict):
prompt_template: PromptTemplateParser
custom_variable_keys: list[str]
special_variable_keys: list[str]
prompt_rules: dict[str, Any]
class SimplePromptTransform(PromptTransform):
"""
Simple Prompt Transform for Chatbot App Basic Mode.
@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
with_memory_prompt=histories is not None,
)
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
custom_variable_keys = prompt_template_config["custom_variable_keys"]
if not isinstance(custom_variable_keys, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
# Type check for custom_variable_keys
if not isinstance(custom_variable_keys_obj, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
# Type check for special_variable_keys
if not isinstance(special_variable_keys_obj, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
special_variable_keys = cast(list[str], special_variable_keys_obj)
special_variable_keys = prompt_template_config["special_variable_keys"]
if not isinstance(special_variable_keys, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
) -> dict[str, object]:
) -> PromptTemplateConfigDict:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
custom_variable_keys: list[str] = []
@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
special_variable_keys.append("#query#")
return {
result: PromptTemplateConfigDict = {
"prompt_template": PromptTemplateParser(template=prompt),
"custom_variable_keys": custom_variable_keys,
"special_variable_keys": special_variable_keys,
"prompt_rules": prompt_rules,
}
return result
def _get_chat_model_prompt_messages(
self,

View File

@ -6,7 +6,7 @@ from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -63,11 +63,11 @@ class IndexProcessor:
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> IndexingResultDict:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
@ -104,12 +104,12 @@ class IndexProcessor:
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
session.scalar(
select(func.sum(DocumentSegment.word_count)).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
)
.scalar()
) or 0
# Update need_summary based on dataset's summary_index_setting
if summary_index_setting and summary_index_setting.get("enable") is True:
@ -118,15 +118,17 @@ class IndexProcessor:
document.need_summary = False
session.add(document)
# update document segment status
session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
.values(
status="completed",
enabled=True,
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
)
)
result: IndexingResultDict = {
@ -151,11 +153,11 @@ class IndexProcessor:
doc_language = None
with session_factory.create_session() as session:
if document_id:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
else:
document = None
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
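A condensed sketch of the bulk-update migration in this file, assuming a mapped `DocumentSegment` model and an open `session`:

from sqlalchemy import update


def mark_segments_completed(session, document_id: str, dataset_id: str) -> None:
    # ORM-enabled UPDATE statement replacing the legacy Query.update({...}) call
    session.execute(
        update(DocumentSegment)
        .where(
            DocumentSegment.document_id == document_id,
            DocumentSegment.dataset_id == dataset_id,
        )
        .values(status="completed", enabled=True)
    )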

View File

@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)

View File

@ -8,6 +8,7 @@ from typing import Any, TypedDict
import pandas as pd
from flask import Flask, current_app
from sqlalchemy import select
from werkzeug.datastructures import FileStorage
from core.db.session_factory import session_factory
@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)

View File

@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, literal, or_, select, update
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import (
@ -276,8 +276,8 @@ class DatasetRetrieval:
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
@ -971,9 +971,11 @@ class DatasetRetrieval:
# Batch update hit_count for all segments
if segment_ids_to_update:
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
session.execute(
update(DocumentSegment)
.where(DocumentSegment.id.in_(segment_ids_to_update))
.values(hit_count=DocumentSegment.hit_count + 1)
.execution_options(synchronize_session=False)
)
self._send_trace_task(message_id, documents, timer)
@ -1822,7 +1824,7 @@ class DatasetRetrieval:
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
@ -1834,13 +1836,12 @@ class DatasetRetrieval:
.subquery()
)
results = (
session.query(Dataset)
results = session.scalars(
select(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
).all()
available_datasets = []
for dataset in results:

View File

@ -1,6 +1,8 @@
import concurrent.futures
import logging
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@ -21,7 +23,7 @@ class SummaryIndex:
) -> None:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
@ -34,32 +36,31 @@ class SummaryIndex:
if not document_id:
return
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
# Skip qa_model documents
if document is None or document.doc_form == "qa_model":
return
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset_id,
document_id=document_id,
status="completed",
enabled=True,
)
segments = query.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
segment_ids = [segment.id for segment in segments]
if not segment_ids:
return
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.status == "completed",
)
.all()
)
).all()
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
# Preview mode should process segments that are MISSING completed summaries
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
@ -73,7 +74,7 @@ class SummaryIndex:
def process_segment(segment_id: str) -> None:
"""Process a single segment in a thread with a fresh DB session."""
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if segment is None:
return
try:

View File

@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel):
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMetaDict(TypedDict):
time_cost: float
error: str | None
tool_config: dict[str, Any] | None
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
@ -473,12 +479,13 @@ class ToolInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self):
return {
def to_dict(self) -> ToolInvokeMetaDict:
result: ToolInvokeMetaDict = {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
return result
class ToolLabel(BaseModel):

View File

@ -262,6 +262,8 @@ class ToolEngine:
ensure_ascii=False,
)
)
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
continue
else:
parts.append(str(response.message))

View File

@ -682,7 +682,7 @@ class ToolManager:
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))
@classmethod
def list_providers_from_api(
@ -993,7 +993,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -1001,7 +1001,7 @@ class ToolManager:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
@ -1013,7 +1013,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | EmojiIconDict | dict[str, str]:
) -> str | EmojiIconDict:
"""
get the tool icon

View File

@ -17,10 +17,8 @@ class WorkflowToolConfigurationUtils:
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod

View File

@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
class RedisSSLParamsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
class RedisHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
db: int
encoding: str
encoding_errors: str
decode_responses: bool
protocol: int
cache_config: CacheConfig | None
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL:
@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
)
def _get_connection_health_params() -> dict[str, Any]:
def _get_connection_health_params() -> RedisHealthParamsDict:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
return RedisHealthParamsDict(
retry=_get_retry_policy(),
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
)
def _get_cluster_connection_health_params() -> dict[str, Any]:
@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
params: dict[str, Any] = dict(_get_connection_health_params())
return {k: v for k, v in params.items() if k != "health_check_interval"}
def _get_base_redis_params() -> dict[str, Any]:
def _get_base_redis_params() -> RedisBaseParamsDict:
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
return RedisBaseParamsDict(
username=dify_config.REDIS_USERNAME,
password=dify_config.REDIS_PASSWORD or None,
db=dify_config.REDIS_DB,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
**_get_connection_health_params(),
}
)
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_kwargs=sentinel_kwargs,
)
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
params: dict[str, Any] = {**redis_params}
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
return master
@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
return cluster
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)
params: dict[str, Any] = {
**redis_params,
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
if dify_config.REDIS_MAX_CONNECTIONS:
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
kwargs = {**standalone_health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
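As a standalone sketch of the parameter-typing pattern above (names illustrative), a TypedDict can be constructed with keyword arguments and then widened to a plain dict when extra keys are merged in:

from typing import Any, TypedDict


class HealthParamsDict(TypedDict):
    socket_timeout: float | None
    socket_connect_timeout: float | None


def health_params() -> HealthParamsDict:
    return HealthParamsDict(socket_timeout=5.0, socket_connect_timeout=2.0)


# keys not declared on the TypedDict go into a widened dict before ** expansion
params: dict[str, Any] = {**health_params(), "host": "localhost", "port": 6379}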

View File

@ -14,9 +14,15 @@ from __future__ import annotations
import logging
import threading
from typing import Any
from typing import TYPE_CHECKING, Any
import redis
from redis.cluster import RedisCluster
from redis.exceptions import LockNotOwnedError, RedisError
from redis.lock import Lock
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
logger = logging.getLogger(__name__)
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
primary error/exit code.
"""
_redis_client: Any
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_lock: Lock | None
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()
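For reference, the renewal loop in isolation (a sketch; the error handling in the class above is omitted):

import threading

from redis.lock import Lock


def heartbeat_loop(lock: Lock, stop_event: threading.Event, interval_seconds: float) -> None:
    # wait() returns False on timeout, so each interval the lock's TTL is reset
    while not stop_event.wait(interval_seconds):
        lock.reacquire()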

View File

@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
class DefaultFieldsMixin:
"""Mixin for models that inherit from Base (non-dataclass)."""
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
return f"<{self.__class__.__name__}(id={self.id})>"
class DefaultFieldsDCMixin(MappedAsDataclass):
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
__abstract__ = True
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string.

View File

@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict):
created_at: str
class DocumentDict(TypedDict):
id: str
tenant_id: str
dataset_id: str
position: int
data_source_type: str
data_source_info: str | None
dataset_process_rule_id: str | None
batch: str
name: str
created_from: str
created_by: str
created_api_request_id: str | None
created_at: datetime
processing_started_at: datetime | None
file_id: str | None
word_count: int | None
parsing_completed_at: datetime | None
cleaning_completed_at: datetime | None
splitting_completed_at: datetime | None
tokens: int | None
indexing_latency: float | None
completed_at: datetime | None
is_paused: bool | None
paused_by: str | None
paused_at: datetime | None
error: str | None
stopped_at: datetime | None
indexing_status: str
enabled: bool
disabled_at: datetime | None
disabled_by: str | None
archived: bool
archived_reason: str | None
archived_by: str | None
archived_at: datetime | None
updated_at: datetime
doc_type: str | None
doc_metadata: Any
doc_form: IndexStructureType
doc_language: str | None
display_status: str | None
data_source_info_dict: dict[str, Any]
average_segment_length: int
dataset_process_rule: ProcessRuleDict | None
dataset: None
segment_count: int | None
hit_count: int | None
class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me"
ALL_TEAM = "all_team_members"
@ -675,8 +725,8 @@ class Document(Base):
)
return built_in_fields
def to_dict(self) -> dict[str, Any]:
return {
def to_dict(self) -> DocumentDict:
result: DocumentDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"dataset_id": self.dataset_id,
@ -721,10 +771,11 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"dataset": None,
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
return result
@classmethod
def from_dict(cls, data: dict[str, Any]):

View File

@ -809,11 +809,11 @@ class AccountService:
rest of the system gradually normalizes new inputs.
"""
with session_factory.create_session() as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:

View File

@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:
for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.where(
records = session.scalars(
select(model).where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
).all()
if len(records) == 0:
continue
@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
session.execute(
delete(model)
.where(
model.id.in_(record_ids), # type: ignore
)
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
app_ids = [app.id for app in apps]
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
messages = (
session.query(Message)
messages = session.scalars(
select(Message)
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(messages) == 0:
break
@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]
# delete messages
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
session.execute(
delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
)
cls._clear_message_related_tables(session, tenant_id, message_ids)
@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
conversations = (
session.query(Conversation)
conversations = session.scalars(
select(Conversation)
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(conversations) == 0:
break
@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
)
conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.execute(
delete(Conversation)
.where(Conversation.id.in_(conversation_ids))
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
workflow_app_logs = session.scalars(
select(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(workflow_app_logs) == 0:
break
@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id.in_(workflow_app_log_ids))
.execution_options(synchronize_session=False)
)
click.echo(
@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
current_time = started_at
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
session.scalar(
select(func.count(Tenant.id)).where(
Tenant.created_at.between(current_time, current_time + test_interval)
)
)
or 0
)
if tenant_count <= 100:
interval = test_interval
@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
rs = session.execute(
select(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)

View File

@ -552,8 +552,8 @@ class DatasetService:
external_knowledge_api_id: External knowledge API identifier
"""
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
external_knowledge_binding = session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
)
if not external_knowledge_binding:
@ -1454,15 +1454,17 @@ class DocumentService:
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
result = session.execute(
update(Document)
.where(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
.values(need_summary=need_summary)
.execution_options(synchronize_session=False)
)
updated_count = result.rowcount # type: ignore[union-attr,attr-defined]
session.commit()
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",
@ -2822,6 +2824,10 @@ class DocumentService:
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
if not knowledge_config.process_rule.rules.parent_mode:
knowledge_config.process_rule.rules.parent_mode = "paragraph"
if not knowledge_config.process_rule.rules.segmentation:
raise ValueError("Process rule segmentation is required")

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping
from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -54,11 +54,13 @@ class DatasourceProviderService:
remove oauth custom client params
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.execute(
delete(DatasourceOauthTenantParamConfig).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
)
def decrypt_datasource_provider_credentials(
self,
@ -110,15 +112,21 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
.limit(1)
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
.limit(1)
)
if not datasource_provider:
return {}
@ -173,12 +181,15 @@ class DatasourceProviderService:
get all datasource credentials by provider
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_providers = session.scalars(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
).all()
if not datasource_providers:
return []
current_user = get_current_user()
@ -232,15 +243,15 @@ class DatasourceProviderService:
update datasource provider name
"""
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -250,16 +261,16 @@ class DatasourceProviderService:
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
)
.count()
> 0
):
or 0
) > 0:
raise ValueError("Authorization name is already exists")
target_provider.name = name
@ -273,26 +284,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
session.execute(
update(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == target_provider.provider,
DatasourceProvider.plugin_id == target_provider.plugin_id,
DatasourceProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)
# set new default provider
target_provider.is_default = True
@ -311,14 +327,14 @@ class DatasourceProviderService:
if client_params is None and enabled is None:
return
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if not tenant_oauth_client_params:
@ -351,9 +367,14 @@ class DatasourceProviderService:
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.limit(1)
)
is not None
)
@ -423,15 +444,15 @@ class DatasourceProviderService:
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == provider,
DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
DatasourceOauthTenantParamConfig.enabled.is_(True),
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@ -443,8 +464,13 @@ class DatasourceProviderService:
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
oauth_client_params = session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == provider,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
.limit(1)
)
if oauth_client_params:
return oauth_client_params.system_credentials
@ -455,15 +481,13 @@ class DatasourceProviderService:
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
db_providers = session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@ -485,8 +509,10 @@ class DatasourceProviderService:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -496,25 +522,28 @@ class DatasourceProviderService:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
)
)
.count()
or 0
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -556,25 +585,27 @@ class DatasourceProviderService:
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == credential_type.value,
)
)
.count()
> 0
):
or 0
) > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -627,11 +658,16 @@ class DatasourceProviderService:
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.plugin_id == plugin_id,
DatasourceProvider.provider == provider_name,
DatasourceProvider.name == db_provider_name,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
try:
@ -918,21 +954,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
datasource_provider.name = name

View File

@ -13,6 +13,7 @@ import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from pydantic import TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
@ -66,7 +67,7 @@ class PluginMigration:
current_time = started_at
with Session(db.engine) as session:
total_tenant_count = session.query(Tenant.id).count()
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@ -123,9 +124,12 @@ class PluginMigration:
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
session.scalar(
select(func.count(Tenant.id)).where(
Tenant.created_at.between(current_time, current_time + test_interval)
)
)
or 0
)
if tenant_count <= 100:
interval = test_interval
@ -147,8 +151,8 @@ class PluginMigration:
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
rs = session.execute(
select(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
@ -235,7 +239,7 @@ class PluginMigration:
Extract tool tables.
"""
with Session(db.engine) as session:
rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all()
result = []
for row in rs:
result.append(ToolProviderID(row.provider).plugin_id)
@ -249,7 +253,7 @@ class PluginMigration:
"""
with Session(db.engine) as session:
rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all()
result = []
for row in rs:
graph = row.graph_dict
@ -272,7 +276,7 @@ class PluginMigration:
Extract app tables.
"""
with Session(db.engine) as session:
apps = session.query(App).where(App.tenant_id == tenant_id).all()
apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
if not apps:
return []
@ -280,7 +284,7 @@ class PluginMigration:
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all()
result = []
for row in rs:
agent_config = row.agent_mode_dict
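
The PluginMigration hunks apply the same recipe to multi-row reads: whole entities go through Session.scalars(), while selecting individual columns goes through Session.execute() and yields Row objects. A sketch under the same assumptions (the hypothetical MyModel defined in the sketch above):

from sqlalchemy import select
from sqlalchemy.orm import Session

def load_entities(session: Session, tenant_id: str) -> list[MyModel]:
    # scalars() returns ORM instances, mirroring the old Query(...).all()
    return list(session.scalars(select(MyModel).where(MyModel.tenant_id == tenant_id)).all())

def load_id_name_pairs(session: Session, tenant_id: str) -> list[tuple[str, str]]:
    # execute() returns Row objects when individual columns are selected
    rows = session.execute(select(MyModel.id, MyModel.name).where(MyModel.tenant_id == tenant_id)).all()
    return [(row.id, row.name) for row in rows]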

View File

@ -283,7 +283,9 @@ class RagPipelineDslService:
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
if not dataset:
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
datasets = self._session.scalars(
select(Dataset).where(Dataset.tenant_id == account.current_tenant_id)
).all()
names = [dataset.name for dataset in datasets]
generate_name = generate_incremental_name(names, name)
dataset = Dataset(
@ -303,8 +305,8 @@ class RagPipelineDslService:
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
dataset_collection_binding = self._session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@ -312,7 +314,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
@ -440,8 +442,8 @@ class RagPipelineDslService:
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
dataset_collection_binding = self._session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@ -449,7 +451,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
@ -591,14 +593,14 @@ class RagPipelineDslService:
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
self._session.query(Workflow)
workflow = self._session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# create draft workflow if not found
@ -665,14 +667,12 @@ class RagPipelineDslService:
:param pipeline: Pipeline instance
"""
workflow = (
self._session.query(Workflow)
.where(
workflow = self._session.scalar(
select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
@ -904,15 +904,16 @@ class RagPipelineDslService:
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
self._session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
if self._session.scalar(
select(Dataset).where(
Dataset.name == rag_pipeline_dataset_create_entity.name,
Dataset.tenant_id == tenant_id,
)
):
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
else:
# generate an incremental name like Untitled 1, 2, 3 ...
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
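
The RagPipelineDslService hunks also cover the single-row variants: an existence probe is simply scalar(select(...)) tested against None, and "oldest matching row" becomes order_by(...).limit(1) under scalar(). A sketch, reusing the hypothetical MyModel from the earlier sketch and assuming an additional created_at datetime column:

from sqlalchemy import select
from sqlalchemy.orm import Session

def exists(session: Session, tenant_id: str, name: str) -> bool:
    # scalar() returns None when no row matches
    return (
        session.scalar(select(MyModel).where(MyModel.tenant_id == tenant_id, MyModel.name == name))
        is not None
    )

def oldest(session: Session, tenant_id: str) -> MyModel | None:
    # assumes MyModel also defines created_at: Mapped[datetime]
    return session.scalar(
        select(MyModel).where(MyModel.tenant_id == tenant_id).order_by(MyModel.created_at).limit(1)
    )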

View File

@ -8,6 +8,7 @@ from typing import TypedDict, cast
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
@ -109,8 +110,13 @@ class SummaryIndexService:
"""
with session_factory.create_session() as session:
# Check if summary record already exists
existing_summary = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
existing_summary = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if existing_summary:
@ -309,8 +315,10 @@ class SummaryIndexService:
summary_record_id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
@ -323,10 +331,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@ -487,8 +498,10 @@ class SummaryIndexService:
with session_factory.create_session() as error_session:
# Try to find the record by id first
# Note: Using assignment only (no type annotation) to avoid redeclaration error
summary_record_in_session = (
error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
# Try to find by chunk_id and dataset_id
@ -500,10 +513,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
error_session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
@ -551,14 +567,12 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query existing summary records
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset.id,
)
.all()
)
).all()
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
# Create or update records
@ -603,8 +617,13 @@ class SummaryIndexService:
error: Error message
"""
with session_factory.create_session() as session:
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -639,8 +658,13 @@ class SummaryIndexService:
with session_factory.create_session() as session:
try:
# Get or refresh summary record in this session
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@ -710,8 +734,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to generate summary for segment %s", segment.id)
# Update summary record with error status
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
summary_record_in_session.status = SummaryStatus.ERROR
@ -769,17 +798,17 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query segments (only enabled segments)
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True, # Only generate summaries for enabled segments
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments
)
if segment_ids:
query = query.filter(DocumentSegment.id.in_(segment_ids))
stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
segments = query.all()
segments = list(session.scalars(stmt).all())
if not segments:
logger.info("No segments found for document %s", document.id)
@ -848,15 +877,15 @@ class SummaryIndexService:
from libs.datetime_utils import naive_utc_now
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=True, # Only disable enabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -911,15 +940,15 @@ class SummaryIndexService:
return
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=False, # Only enable disabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -935,13 +964,13 @@ class SummaryIndexService:
enabled_count = 0
for summary in summaries:
# Get the original segment
segment = (
session.query(DocumentSegment)
.filter_by(
id=summary.chunk_id,
dataset_id=dataset.id,
segment = session.scalar(
select(DocumentSegment)
.where(
DocumentSegment.id == summary.chunk_id,
DocumentSegment.dataset_id == dataset.id,
)
.first()
.limit(1)
)
# Summary.enabled stays in sync with chunk.enabled,
@ -988,12 +1017,12 @@ class SummaryIndexService:
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -1046,10 +1075,13 @@ class SummaryIndexService:
# Check if summary_content is empty (whitespace-only strings are considered empty)
if not summary_content or not summary_content.strip():
# If summary is empty, only delete existing summary vector and record
summary_record = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -1077,8 +1109,13 @@ class SummaryIndexService:
return None
# Find existing summary record
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -1162,8 +1199,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to update summary for segment %s", segment.id)
# Update summary record with error status if it exists
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
summary_record.status = SummaryStatus.ERROR
@ -1185,14 +1227,14 @@ class SummaryIndexService:
DocumentSegmentSummary instance if found, None otherwise
"""
with session_factory.create_session() as session:
return (
session.query(DocumentSegmentSummary)
return session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.first()
.limit(1)
)
@staticmethod
@ -1211,15 +1253,13 @@ class SummaryIndexService:
return {}
with session_factory.create_session() as session:
summary_records = (
session.query(DocumentSegmentSummary)
.where(
summary_records = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.all()
)
).all()
return {summary.chunk_id: summary for summary in summary_records}
@ -1239,16 +1279,16 @@ class SummaryIndexService:
List of DocumentSegmentSummary instances (only enabled summaries)
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter(
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
return query.all()
return list(session.scalars(stmt).all())
@staticmethod
def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
@ -1265,16 +1305,15 @@ class SummaryIndexService:
"""
# Get all segments for this document (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
segment_ids = list(
session.scalars(
select(DocumentSegment.id).where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
).all()
)
segment_ids = [seg.id for seg in segments]
if not segment_ids:
return None
@ -1312,15 +1351,13 @@ class SummaryIndexService:
# Get all segments for these documents (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
segments = session.execute(
select(DocumentSegment.id, DocumentSegment.document_id).where(
DocumentSegment.document_id.in_(document_ids),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
)
).all()
# Group segments by document_id
document_segments_map: dict[str, list[str]] = {}
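
The SummaryIndexService hunks build a select() incrementally before executing it, and swap boolean comparisons from == True to .is_(True). A sketch with the same hypothetical MyModel, assuming an additional enabled boolean column:

from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session

def load_enabled(session: Session, tenant_id: str, ids: Sequence[str] | None = None) -> list[MyModel]:
    stmt = select(MyModel).where(
        MyModel.tenant_id == tenant_id,
        MyModel.enabled.is_(True),  # renders IS TRUE instead of a bare == comparison
    )
    if ids:
        # narrow the statement only when a subset of ids was requested
        stmt = stmt.where(MyModel.id.in_(ids))
    return list(session.scalars(stmt).all())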

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping
from pathlib import Path
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -47,11 +47,15 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider)
with sessionmaker(bind=db.engine).begin() as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.execute(
delete(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
)
.execution_options(synchronize_session=False)
)
return {"result": "success"}
@staticmethod
@ -151,13 +155,13 @@ class BuiltinToolManageService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
db_provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
@ -228,7 +232,13 @@ class BuiltinToolManageService:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
session.scalar(
select(func.count(BuiltinToolProvider.id)).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
)
or 0
)
# check if the provider count is reached the limit
@ -304,16 +314,15 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type,
db_providers = session.scalars(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.credential_type == credential_type,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@ -375,13 +384,13 @@ class BuiltinToolManageService:
delete tool provider
"""
with sessionmaker(bind=db.engine).begin() as session:
db_provider = (
session.query(BuiltinToolProvider)
db_provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
if db_provider is None:
@ -405,14 +414,26 @@ class BuiltinToolManageService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
session.execute(
update(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.user_id == user_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)
# set new default provider
target_provider.is_default = True
@ -426,10 +447,13 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
system_client = session.scalar(
select(ToolOAuthSystemClient)
.where(
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
ToolOAuthSystemClient.provider == tool_provider.provider_name,
)
.limit(1)
)
return system_client is not None
@ -440,15 +464,15 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
user_client = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
return user_client is not None and user_client.enabled
@ -465,15 +489,15 @@ class BuiltinToolManageService:
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
user_client = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
oauth_params: Mapping[str, Any] | None = None
if user_client:
@ -487,10 +511,13 @@ class BuiltinToolManageService:
if not is_verified:
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
system_client = session.scalar(
select(ToolOAuthSystemClient)
.where(
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
ToolOAuthSystemClient.provider == tool_provider.provider_name,
)
.limit(1)
)
if system_client:
try:
@ -582,8 +609,8 @@ class BuiltinToolManageService:
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
@ -592,11 +619,11 @@ class BuiltinToolManageService:
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
else:
provider = (
session.query(BuiltinToolProvider)
provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
@ -606,7 +633,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
if provider is None:
@ -616,14 +643,14 @@ class BuiltinToolManageService:
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
return session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
@staticmethod
@ -648,14 +675,14 @@ class BuiltinToolManageService:
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with sessionmaker(bind=db.engine).begin() as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
custom_client_params = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
)
.first()
.limit(1)
)
# if the record does not exist, create a basic record
@ -692,14 +719,14 @@ class BuiltinToolManageService:
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
custom_oauth_client_params = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
)
.first()
.limit(1)
)
if custom_oauth_client_params is None:
return {}
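
The BuiltinToolManageService hunks move bulk writes to statement-level delete() and update(), with synchronize_session=False carried through execution_options(). A sketch, assuming the hypothetical MyModel additionally has an is_default boolean column:

from sqlalchemy import delete, update
from sqlalchemy.orm import Session

def clear_default(session: Session, tenant_id: str) -> None:
    # bulk UPDATE without loading rows into the session
    session.execute(
        update(MyModel)
        .where(MyModel.tenant_id == tenant_id, MyModel.is_default.is_(True))
        .values(is_default=False)
        .execution_options(synchronize_session=False)
    )

def purge(session: Session, tenant_id: str) -> None:
    # bulk DELETE with the same execution option
    session.execute(
        delete(MyModel)
        .where(MyModel.tenant_id == tenant_id)
        .execution_options(synchronize_session=False)
    )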

View File

@ -84,6 +84,7 @@ class WorkflowToolManageService:
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
logger.warning(e, exc_info=True)
raise ValueError(str(e))
with Session(db.engine, expire_on_commit=False) as session, session.begin():
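
The one-line addition above logs the original exception before it is re-raised as a ValueError. If the original traceback should also survive the conversion, exception chaining is an option (a sketch, not what this diff does; risky_operation is a placeholder):

import logging

logger = logging.getLogger(__name__)

def risky_operation() -> None:
    # placeholder for WorkflowToolProviderController.from_db(...)
    raise RuntimeError("boom")

try:
    risky_operation()
except Exception as e:
    logger.warning(e, exc_info=True)
    raise ValueError(str(e)) from e  # keeps the original exception as __cause__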

View File

@ -3,9 +3,9 @@ import logging
import time as _time
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from sqlalchemy import desc, func
from sqlalchemy import delete, desc, func, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -42,6 +42,10 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class VerifyCredentialsResult(TypedDict):
verified: bool
class TriggerProviderService:
"""Service for managing trigger providers and credentials"""
@ -69,27 +73,28 @@ class TriggerProviderService:
workflows_in_use_map: dict[str, int] = {}
with Session(db.engine, expire_on_commit=False) as session:
# Get all subscriptions
subscriptions_db = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
subscriptions_db = session.scalars(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
)
.order_by(desc(TriggerSubscription.created_at))
.all()
)
).all()
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
if not subscriptions:
return []
usage_counts = (
session.query(
usage_counts = session.execute(
select(
WorkflowPluginTrigger.subscription_id,
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
)
.filter(
.where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
)
.group_by(WorkflowPluginTrigger.subscription_id)
.all()
)
).all()
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
@ -152,9 +157,13 @@ class TriggerProviderService:
with redis_client.lock(lock_key, timeout=20):
# Check provider count limit
provider_count = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.count()
session.scalar(
select(func.count(TriggerSubscription.id)).where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
)
)
or 0
)
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
@ -164,10 +173,14 @@ class TriggerProviderService:
)
# Check if name already exists
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
existing = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
TriggerSubscription.name == name,
)
.limit(1)
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
@ -244,8 +257,13 @@ class TriggerProviderService:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger subscription {subscription_id} not found")
@ -255,10 +273,14 @@ class TriggerProviderService:
# Check for name uniqueness if name is being updated
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
existing = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
TriggerSubscription.name == name,
)
.limit(1)
)
if existing:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
@ -316,11 +338,18 @@ class TriggerProviderService:
with Session(db.engine, expire_on_commit=False) as session:
subscription: TriggerSubscription | None = None
if subscription_id:
subscription = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
else:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
subscription = session.scalar(
select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1)
)
if subscription:
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(subscription.provider_id)
@ -349,8 +378,13 @@ class TriggerProviderService:
:param subscription_id: Subscription instance ID
:return: Success response
"""
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -402,7 +436,14 @@ class TriggerProviderService:
:return: New token info
"""
with sessionmaker(bind=db.engine).begin() as session:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -475,8 +516,13 @@ class TriggerProviderService:
now_ts: int = int(now if now is not None else _time.time())
with sessionmaker(bind=db.engine).begin() as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if subscription is None:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -552,15 +598,15 @@ class TriggerProviderService:
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine, expire_on_commit=False) as session:
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
enabled=True,
tenant_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
oauth_params: Mapping[str, Any] | None = None
@ -578,10 +624,13 @@ class TriggerProviderService:
return None
# Check for system-level OAuth client
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)
if system_client:
@ -602,10 +651,13 @@ class TriggerProviderService:
if not is_verified:
return False
with Session(db.engine, expire_on_commit=False) as session:
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)
return system_client is not None
@ -636,14 +688,14 @@ class TriggerProviderService:
with sessionmaker(bind=db.engine).begin() as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)
# Create a new record if it doesn't exist
@ -690,14 +742,14 @@ class TriggerProviderService:
:return: Masked OAuth client parameters
"""
with Session(db.engine) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)
if custom_client is None:
@ -727,11 +779,15 @@ class TriggerProviderService:
:return: Success response
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.execute(
delete(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
)
.execution_options(synchronize_session=False)
)
return {"result": "success"}
@ -745,15 +801,15 @@ class TriggerProviderService:
:return: True if enabled, False otherwise
"""
with Session(db.engine, expire_on_commit=False) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
enabled=True,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
return custom_client is not None
@ -763,7 +819,9 @@ class TriggerProviderService:
Get a trigger subscription by the endpoint ID.
"""
with Session(db.engine, expire_on_commit=False) as session:
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
subscription = session.scalar(
select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1)
)
if not subscription:
return None
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
@ -792,7 +850,7 @@ class TriggerProviderService:
provider_id: TriggerProviderID,
subscription_id: str,
credentials: dict[str, Any],
) -> dict[str, Any]:
) -> VerifyCredentialsResult:
"""
Verify credentials for an existing subscription without updating it.

View File

@ -19,7 +19,7 @@ from graphon.variables.segments import (
)
from graphon.variables.types import SegmentType
from graphon.variables.utils import dumps_with_segments
from sqlalchemy import Engine, orm, select
from sqlalchemy import Engine, delete, orm, select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, sessionmaker
@ -222,11 +222,10 @@ class WorkflowDraftVariableService:
)
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
return self._session.scalar(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(WorkflowDraftVariable.id == variable_id)
.first()
)
def get_draft_variables_by_selectors(
@ -254,20 +253,21 @@ class WorkflowDraftVariableService:
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
query = (
self._session.query(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
return list(
self._session.scalars(
select(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
)
)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
return query.all()
def list_variables_without_values(
self, app_id: str, page: int, limit: int, user_id: str
@ -277,18 +277,21 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.user_id == user_id,
]
total = None
query = self._session.query(WorkflowDraftVariable).where(*criteria)
base_stmt = select(WorkflowDraftVariable).where(*criteria)
if page == 1:
total = query.count()
variables = (
# Do not load the `value` field
query.options(
orm.defer(WorkflowDraftVariable.value, raiseload=True),
from sqlalchemy import func as sa_func
total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery()))
variables = list(
self._session.scalars(
# Do not load the `value` field
base_stmt.options(
orm.defer(WorkflowDraftVariable.value, raiseload=True),
)
.order_by(WorkflowDraftVariable.created_at.desc())
.limit(limit)
.offset((page - 1) * limit)
)
.order_by(WorkflowDraftVariable.created_at.desc())
.limit(limit)
.offset((page - 1) * limit)
.all()
)
return WorkflowDraftVariableList(variables=variables, total=total)
@ -299,11 +302,13 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
]
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = (
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.order_by(WorkflowDraftVariable.created_at.desc())
.all()
variables = list(
self._session.scalars(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(*criteria)
.order_by(WorkflowDraftVariable.created_at.desc())
)
)
return WorkflowDraftVariableList(variables=variables)
@ -326,8 +331,8 @@ class WorkflowDraftVariableService:
return self._get_variable(app_id, node_id, name, user_id=user_id)
def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
return self._session.scalar(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(
WorkflowDraftVariable.app_id == app_id,
@ -335,7 +340,6 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.name == name,
WorkflowDraftVariable.user_id == user_id,
)
.first()
)
def update_variable(
@ -488,20 +492,20 @@ class WorkflowDraftVariableService:
self._session.delete(variable)
def delete_user_workflow_variables(self, app_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)
def delete_app_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)
def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]):
@ -540,14 +544,14 @@ class WorkflowDraftVariableService:
return self._delete_node_variables(app_id, node_id, user_id=user_id)
def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)
def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
@ -588,13 +592,11 @@ class WorkflowDraftVariableService:
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)
if conv_id is not None:
conversation = (
self._session.query(Conversation)
.where(
conversation = self._session.scalar(
select(Conversation).where(
Conversation.id == conv_id,
Conversation.app_id == workflow.app_id,
)
.first()
)
# Only return the conversation ID if it exists and is valid (has a corresponding conversation record in DB).
if conversation is not None:
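
The WorkflowDraftVariableService hunks compute the page-one total by counting over a subquery built from the same statement that serves the page, replacing Query.count(). A sketch, reusing the hypothetical MyModel and its assumed created_at column:

from sqlalchemy import func, select
from sqlalchemy.orm import Session

def list_page(session: Session, tenant_id: str, page: int, limit: int) -> tuple[list[MyModel], int | None]:
    base_stmt = select(MyModel).where(MyModel.tenant_id == tenant_id)
    total = None
    if page == 1:
        # COUNT over the filtered statement so both queries share the same criteria
        total = session.scalar(select(func.count()).select_from(base_stmt.subquery()))
    items = list(
        session.scalars(
            base_stmt.order_by(MyModel.created_at.desc()).limit(limit).offset((page - 1) * limit)
        ).all()
    )
    return items, total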

View File

@ -1512,14 +1512,12 @@ class WorkflowService:
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
# Check if there's a tool provider using this specific workflow version
tool_provider = (
session.query(WorkflowToolProvider)
.where(
tool_provider = session.scalar(
select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version,
)
.first()
)
if tool_provider:

View File

@ -5,6 +5,7 @@ from typing import Any, Protocol
import click
from celery import current_app, shared_task
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
return
@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
except Exception as e:
for document_id in document_ids:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = IndexingStatus.ERROR
@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
documents: list[Document] = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
for document in documents:
@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.warning("Dataset %s not found after indexing", dataset_id)
return
@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
for document in documents:

View File

@ -6,7 +6,7 @@ from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
session.execute(
delete(AppModelConfig)
.where(AppModelConfig.id == model_config_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(session, site_id: str):
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(session, mcp_server_id: str):
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
session.execute(
delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1))
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
session.execute(
delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(session, installed_app_id: str):
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
session.execute(
delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(session, recommended_app_id: str):
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
session.execute(
delete(RecommendedApp)
.where(RecommendedApp.id == recommended_app_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(session, annotation_hit_history_id: str):
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.id == annotation_hit_history_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)
def del_annotation_setting(session, annotation_setting_id: str):
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationSetting)
.where(AppAnnotationSetting.id == annotation_setting_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(session, dataset_join_id: str):
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
session.execute(
delete(AppDatasetJoin)
.where(AppDatasetJoin.id == dataset_join_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(session, workflow_id: str):
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(session, workflow_app_log_id: str):
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id == workflow_app_log_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowArchiveLog)
.where(WorkflowArchiveLog.id == workflow_archive_log_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(session, conversation_id: str):
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
session.execute(
delete(PinnedConversation)
.where(PinnedConversation.conversation_id == conversation_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False)
)
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str):
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(session, message_id: str):
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageFeedback)
.where(MessageFeedback.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageAnnotation)
.where(MessageAnnotation.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
session.query(Message).where(Message.id == message_id).delete()
session.execute(
delete(MessageChain)
.where(MessageChain.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageAgentThought)
.where(MessageAgentThought.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False)
)
session.execute(
delete(SavedMessage)
.where(SavedMessage.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(session, tool_provider_id: str):
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowToolProvider)
.where(WorkflowToolProvider.id == tool_provider_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(session, tag_binding_id: str):
- session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+ session.execute(
+     delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False)
+ )
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(session, end_user_id: str):
- session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
+ session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(session, trace_app_config_id: str):
- session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
+ session.execute(
+     delete(TraceAppConfig)
+     .where(TraceAppConfig.id == trace_app_config_id)
+     .execution_options(synchronize_session=False)
+ )
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(session, trigger_id: str):
- session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
+ session.execute(
+     delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False)
+ )
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(session, trigger_id: str):
- session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
-     synchronize_session=False
- )
+ session.execute(
+     delete(WorkflowPluginTrigger)
+     .where(WorkflowPluginTrigger.id == trigger_id)
+     .execution_options(synchronize_session=False)
+ )
_delete_records(
@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(session, trigger_id: str):
- session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
-     synchronize_session=False
- )
+ session.execute(
+     delete(WorkflowWebhookTrigger)
+     .where(WorkflowWebhookTrigger.id == trigger_id)
+     .execution_options(synchronize_session=False)
+ )
_delete_records(
@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(session, plan_id: str):
- session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
+ session.execute(
+     delete(WorkflowSchedulePlan)
+     .where(WorkflowSchedulePlan.id == plan_id)
+     .execution_options(synchronize_session=False)
+ )
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(session, log_id: str):
- session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
+ session.execute(
+     delete(WorkflowTriggerLog)
+     .where(WorkflowTriggerLog.id == log_id)
+     .execution_options(synchronize_session=False)
+ )
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",

View File

@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers:
self.session.refresh(self.test_workflow_run)
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
- pause_states = (
-     self.session.query(WorkflowPauseModel)
-     .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
-     .all()
- )
+ pause_states = self.session.scalars(
+     select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
+ ).all()
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):

View File

@ -0,0 +1,149 @@
"""
Integration tests for Conversation.inputs and Message.inputs tenant resolution.
Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching
with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB.
"""
from collections.abc import Generator
from unittest.mock import patch
from uuid import uuid4
import pytest
from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
from sqlalchemy.orm import Session
from core.workflow.file_reference import build_file_reference
from models.model import App, AppMode, Conversation, Message
def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict:
mapping: dict = {
"dify_model_identity": FILE_MODEL_IDENTITY,
"transfer_method": FileTransferMethod.LOCAL_FILE,
"reference": build_file_reference(record_id=record_id),
"type": "document",
"filename": "example.txt",
"extension": ".txt",
"mime_type": "text/plain",
"size": 1,
}
if tenant_id is not None:
mapping["tenant_id"] = tenant_id
return mapping
class TestConversationMessageInputsTenantResolution:
"""Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()
def _create_app(self, db_session: Session) -> App:
tenant_id = str(uuid4())
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.CHAT,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=str(uuid4()),
updated_by=str(uuid4()),
)
db_session.add(app)
db_session.flush()
return app
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolves_tenant_via_db_for_local_file(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs resolves tenant_id from real App row when file mapping has no tenant_id."""
app = self._create_app(db_session_with_containers)
build_calls: list[tuple[dict, str]] = []
def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {"file": _build_local_file_mapping("upload-1")}
restored_inputs = owner.inputs
# The tenant_id should come from the real App row in the DB
assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}
assert len(build_calls) == 1
assert build_calls[0][1] == app.tenant_id
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_uses_serialized_tenant_id_skipping_db_lookup(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs uses tenant_id from the file mapping payload without hitting the DB."""
app = self._create_app(db_session_with_containers)
payload_tenant_id = "tenant-from-payload"
build_calls: list[tuple[dict, str]] = []
def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)}
restored_inputs = owner.inputs
assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"}
assert len(build_calls) == 1
assert build_calls[0][1] == payload_tenant_id
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolves_tenant_for_file_list(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs resolves tenant_id for a list of file mappings."""
app = self._create_app(db_session_with_containers)
build_calls: list[tuple[dict, str]] = []
def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {
"files": [
_build_local_file_mapping("upload-1"),
_build_local_file_mapping("upload-2"),
]
}
restored_inputs = owner.inputs
assert len(build_calls) == 2
assert all(call[1] == app.tenant_id for call in build_calls)
assert restored_inputs["files"] == [
{"tenant_id": app.tenant_id, "upload_file_id": "upload-1"},
{"tenant_id": app.tenant_id, "upload_file_id": "upload-2"},
]
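The behaviour pinned down above is a two-step resolution: a tenant_id already present in the serialized mapping wins, otherwise the owning App row is consulted. The real `_resolve_app_tenant_id` helper is not shown in this diff; a rough, hypothetical sketch of that order of precedence:

from sqlalchemy import select
from sqlalchemy.orm import Session

from models.model import App

# Hypothetical sketch only; the actual helper in models.model may differ in detail.
def resolve_tenant_id(session: Session, app_id: str, mapping: dict) -> str | None:
    # Prefer a tenant_id serialized with the file mapping (no DB round-trip).
    if tenant_id := mapping.get("tenant_id"):
        return tenant_id
    # Otherwise fall back to the tenant of the owning App row.
    return session.scalar(select(App.tenant_id).where(App.id == app_id))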

View File

@ -0,0 +1,314 @@
"""
Integration tests for Conversation.status_count and Site.generate_code model properties.
Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and
test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries.
"""
from collections.abc import Generator
from uuid import uuid4
import pytest
from graphon.enums import WorkflowExecutionStatus
from sqlalchemy.orm import Session
from models.enums import ConversationFromSource, InvokeFrom
from models.model import App, AppMode, Conversation, Message, Site
from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType
class TestConversationStatusCount:
"""Integration tests for Conversation.status_count property."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()
def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App:
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.ADVANCED_CHAT,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=created_by,
updated_by=created_by,
)
db_session.add(app)
db_session.flush()
return app
def _create_conversation(self, db_session: Session, app: App) -> Conversation:
conversation = Conversation(
app_id=app.id,
mode=app.mode,
name=f"Conversation {uuid4()}",
summary="",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.WEB_APP,
from_source=ConversationFromSource.API,
dialogue_count=0,
is_deleted=False,
)
conversation.inputs = {}
db_session.add(conversation)
db_session.flush()
return conversation
def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow:
workflow = Workflow(
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.CHAT,
version="draft",
graph="{}",
created_by=created_by,
)
workflow._features = "{}"
db_session.add(workflow)
db_session.flush()
return workflow
def _create_workflow_run(
self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str
) -> WorkflowRun:
run = WorkflowRun(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type=WorkflowType.CHAT,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="draft",
status=status,
created_by_role="account",
created_by=created_by,
)
db_session.add(run)
db_session.flush()
return run
def _create_message(
self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
_inputs={},
query="Test query",
message={"role": "user", "content": "Test query"},
answer="Test answer",
model_provider="openai",
model_id="gpt-4",
message_tokens=10,
message_unit_price=0,
answer_tokens=10,
answer_unit_price=0,
total_price=0,
currency="USD",
from_source=ConversationFromSource.API,
invoke_from=InvokeFrom.WEB_APP,
workflow_run_id=workflow_run_id,
)
db_session.add(message)
db_session.flush()
return message
def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None:
"""status_count returns None when conversation has no messages with workflow_run_id."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
result = conversation.status_count
assert result is None
def test_status_count_returns_none_when_messages_have_no_workflow_run_id(
self, db_session_with_containers: Session
) -> None:
"""status_count returns None when messages exist but none have workflow_run_id."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None)
result = conversation.status_count
assert result is None
def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts succeeded workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)
result = conversation.status_count
assert result is not None
assert result["success"] == 1
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts failed workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)
result = conversation.status_count
assert result is not None
assert result["success"] == 0
assert result["failed"] == 1
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts paused workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)
result = conversation.status_count
assert result is not None
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 1
def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None:
"""status_count counts multiple workflow runs with different statuses."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
for status in [
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.PAUSED,
]:
run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)
result = conversation.status_count
assert result is not None
assert result["success"] == 1
assert result["failed"] == 1
assert result["partial_success"] == 1
assert result["paused"] == 1
def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None:
"""status_count excludes workflow runs belonging to a different app."""
tenant_id = str(uuid4())
created_by = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, created_by)
other_app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, other_app, created_by)
# Workflow run belongs to other_app, not app
other_run = self._create_workflow_run(
db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by
)
# Message references that run but is in a conversation under app
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id)
result = conversation.status_count
# The run should be excluded because app_id filter doesn't match
assert result is not None
assert result["success"] == 0
class TestSiteGenerateCode:
"""Integration tests for Site.generate_code static method."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()
def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None:
"""Site.generate_code returns a code string of the requested length."""
code = Site.generate_code(8)
assert isinstance(code, str)
assert len(code) == 8
def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None:
"""Site.generate_code returns a code not already in use."""
tenant_id = str(uuid4())
app = App(
tenant_id=tenant_id,
name="Test App",
mode=AppMode.CHAT,
enable_site=True,
enable_api=False,
is_demo=False,
is_public=False,
is_universal=False,
created_by=str(uuid4()),
updated_by=str(uuid4()),
)
db_session_with_containers.add(app)
db_session_with_containers.flush()
site = Site(
app_id=app.id,
title="Test Site",
default_language="en-US",
customize_token_strategy="not_allow",
)
# Set an explicit code so generate_code must avoid it
site.code = "AAAAAAAA"
db_session_with_containers.add(site)
db_session_with_containers.flush()
code = Site.generate_code(8)
assert isinstance(code, str)
assert len(code) == 8
assert code != site.code

View File

@ -6,7 +6,7 @@ import pytest
import sqlalchemy as sa
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import exc as sa_exc
- from sqlalchemy import insert
+ from sqlalchemy import insert, select
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql.sqltypes import VARCHAR
@ -137,12 +137,12 @@ class TestEnumText:
session.commit()
with Session(engine_with_containers) as session:
- user = session.query(_User).where(_User.id == admin_user_id).first()
+ user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1))
assert user.user_type == _UserType.admin
assert user.user_type_nullable is None
with Session(engine_with_containers) as session:
- user = session.query(_User).where(_User.id == normal_user_id).first()
+ user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1))
assert user.user_type == _UserType.normal
assert user.user_type_nullable == _UserType.normal
@ -206,7 +206,7 @@ class TestEnumText:
with pytest.raises(ValueError) as exc:
with Session(engine_with_containers) as session:
- _user = session.query(_User).where(_User.id == 1).first()
+ _user = session.scalar(select(_User).where(_User.id == 1).limit(1))
assert str(exc.value) == "'invalid' is not a valid _UserType"
@ -222,7 +222,7 @@ class TestEnumText:
session.commit()
with Session(engine_with_containers) as session:
- records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all()
+ records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all()
assert [record.model_type for record in records] == [
ModelType.LLM,
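These hunks apply the read-side counterpart of the same migration: `Query.first()` becomes `session.scalar(select(...).limit(1))` and `Query.all()` becomes `session.scalars(select(...)).all()`. A compact reference sketch, with a generic mapped `User` class standing in for the test models:

from sqlalchemy import select
from sqlalchemy.orm import Session

# `User` is assumed to be any mapped class; only the call shapes matter here.
def load(session: Session, User, user_id: int):
    # 1.x: session.query(User).where(User.id == user_id).first()
    first_match = session.scalar(select(User).where(User.id == user_id).limit(1))
    # 1.x: session.query(User).order_by(User.id).all()
    all_rows = session.scalars(select(User).order_by(User.id)).all()
    return first_match, all_rows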

View File

@ -0,0 +1,170 @@
"""
Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user.
Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing
monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows
persisted in PostgreSQL so the db.session.get() call executes against the DB.
"""
from collections.abc import Generator
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
class TestWorkflowNodeExecutionModelCreatedBy:
"""Integration tests for WorkflowNodeExecutionModel creator lookup properties."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()
def _create_account(self, db_session: Session) -> Account:
account = Account(
name="Test Account",
email=f"test_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session.add(account)
db_session.flush()
return account
def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser:
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="service_api",
external_user_id=f"ext-{uuid4()}",
name="End User",
session_id=f"session-{uuid4()}",
)
end_user.is_anonymous = False
db_session.add(end_user)
db_session.flush()
return end_user
def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App:
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.WORKFLOW,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=created_by,
updated_by=created_by,
)
db_session.add(app)
db_session.flush()
return app
def _make_execution(
self, tenant_id: str, app_id: str, created_by_role: str, created_by: str
) -> WorkflowNodeExecutionModel:
return WorkflowNodeExecutionModel(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=created_by_role,
created_by=created_by,
)
def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None:
"""created_by_account returns the Account row when role is ACCOUNT."""
account = self._create_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, str(uuid4()), account.id)
execution = self._make_execution(
tenant_id=app.tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by=account.id,
)
result = execution.created_by_account
assert result is not None
assert result.id == account.id
def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None:
"""created_by_account returns None when role is END_USER, even if an Account exists."""
account = self._create_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, str(uuid4()), account.id)
execution = self._make_execution(
tenant_id=app.tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.END_USER.value,
created_by=account.id,
)
result = execution.created_by_account
assert result is None
def test_created_by_end_user_returns_end_user_when_role_is_end_user(
self, db_session_with_containers: Session
) -> None:
"""created_by_end_user returns the EndUser row when role is END_USER."""
account = self._create_account(db_session_with_containers)
tenant_id = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, account.id)
end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id)
execution = self._make_execution(
tenant_id=tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.END_USER.value,
created_by=end_user.id,
)
result = execution.created_by_end_user
assert result is not None
assert result.id == end_user.id
def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None:
"""created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists."""
account = self._create_account(db_session_with_containers)
tenant_id = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, account.id)
end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id)
execution = self._make_execution(
tenant_id=tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by=end_user.id,
)
result = execution.created_by_end_user
assert result is None
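What these four tests pin down is a role-gated `db.session.get()` lookup: the account property resolves only for rows created by an account, and the end-user property only for rows created by an end user. A hypothetical sketch of that shape (the real properties live on WorkflowNodeExecutionModel and may differ):

from sqlalchemy.orm import Session

from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser

# Hypothetical free functions mirroring the model properties exercised above.
def created_by_account(session: Session, created_by_role: str, created_by: str) -> Account | None:
    if created_by_role != CreatorUserRole.ACCOUNT.value:
        return None
    return session.get(Account, created_by)

def created_by_end_user(session: Session, created_by_role: str, created_by: str) -> EndUser | None:
    if created_by_role != CreatorUserRole.END_USER.value:
        return None
    return session.get(EndUser, created_by)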

View File

@ -0,0 +1,395 @@
"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository."""
from __future__ import annotations
import json
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
from graphon.entities import WorkflowNodeExecution
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig
from models.account import Account, Tenant
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
def _create_account_with_tenant(session: Session) -> Account:
tenant = Tenant(name="Test Workspace")
session.add(tenant)
session.flush()
account = Account(name="test", email=f"test-{uuid4()}@example.com")
session.add(account)
session.flush()
account._current_tenant = tenant
return account
def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository:
engine = session.get_bind()
assert isinstance(engine, Engine)
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=sessionmaker(bind=engine, expire_on_commit=False),
user=account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def _create_node_execution_model(
session: Session,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_run_id: str,
index: int = 1,
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING,
) -> WorkflowNodeExecutionModel:
model = WorkflowNodeExecutionModel(
id=str(uuid4()),
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=workflow_run_id,
index=index,
predecessor_node_id=None,
node_execution_id=str(uuid4()),
node_id=f"node-{index}",
node_type=BuiltinNodeTypes.START,
title=f"Test Node {index}",
inputs='{"input_key": "input_value"}',
process_data='{"process_key": "process_value"}',
outputs='{"output_key": "output_value"}',
status=status,
error=None,
elapsed_time=1.5,
execution_metadata="{}",
created_at=datetime.now(),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
finished_at=None,
)
session.add(model)
session.flush()
return model
class TestSave:
def test_save_new_record(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
repo = _make_repo(db_session_with_containers, account, app_id)
execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
predecessor_node_id=None,
node_id="node-1",
node_type=BuiltinNodeTypes.START,
title="Test Node",
inputs={"input_key": "input_value"},
process_data={"process_key": "process_value"},
outputs={"result": "success"},
status=WorkflowNodeExecutionStatus.RUNNING,
error=None,
elapsed_time=1.5,
metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100},
created_at=datetime.now(),
finished_at=None,
)
repo.save(execution)
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
assert saved is not None
assert saved.tenant_id == account.current_tenant_id
assert saved.app_id == app_id
assert saved.node_id == "node-1"
assert saved.status == WorkflowNodeExecutionStatus.RUNNING
def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
repo = _make_repo(db_session_with_containers, account, str(uuid4()))
execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
predecessor_node_id=None,
node_id="node-1",
node_type=BuiltinNodeTypes.START,
title="Test Node",
inputs=None,
process_data=None,
outputs=None,
status=WorkflowNodeExecutionStatus.RUNNING,
error=None,
elapsed_time=0.0,
metadata=None,
created_at=datetime.now(),
finished_at=None,
)
repo.save(execution)
execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
execution.elapsed_time = 2.5
repo.save(execution)
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
assert saved is not None
assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert saved.elapsed_time == 2.5
class TestGetByWorkflowExecution:
def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
tenant_id = account.current_tenant_id
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
repo = _make_repo(db_session_with_containers, account, app_id)
_create_node_execution_model(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
index=1,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
_create_node_execution_model(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
index=2,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
db_session_with_containers.commit()
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repo.get_by_workflow_execution(
workflow_execution_id=workflow_run_id,
order_config=order_config,
)
assert len(result) == 2
assert result[0].index == 2
assert result[1].index == 1
assert all(isinstance(r, WorkflowNodeExecution) for r in result)
def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
tenant_id = account.current_tenant_id
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
repo = _make_repo(db_session_with_containers, account, app_id)
_create_node_execution_model(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
index=1,
status=WorkflowNodeExecutionStatus.RUNNING,
)
_create_node_execution_model(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
index=2,
status=WorkflowNodeExecutionStatus.PAUSED,
)
db_session_with_containers.commit()
result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id)
assert len(result) == 1
assert result[0].index == 1
class TestToDbModel:
def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
repo = _make_repo(db_session_with_containers, account, app_id)
domain_model = WorkflowNodeExecution(
id="test-id",
workflow_id="test-workflow-id",
node_execution_id="test-node-execution-id",
workflow_execution_id="test-workflow-run-id",
index=1,
predecessor_node_id="test-predecessor-id",
node_id="test-node-id",
node_type=BuiltinNodeTypes.START,
title="Test Node",
inputs={"input_key": "input_value"},
process_data={"process_key": "process_value"},
outputs={"output_key": "output_value"},
status=WorkflowNodeExecutionStatus.RUNNING,
error=None,
elapsed_time=1.5,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
},
created_at=datetime.now(),
finished_at=None,
)
db_model = repo._to_db_model(domain_model)
assert isinstance(db_model, WorkflowNodeExecutionModel)
assert db_model.id == domain_model.id
assert db_model.tenant_id == account.current_tenant_id
assert db_model.app_id == app_id
assert db_model.workflow_id == domain_model.workflow_id
assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
assert db_model.workflow_run_id == domain_model.workflow_execution_id
assert db_model.index == domain_model.index
assert db_model.predecessor_node_id == domain_model.predecessor_node_id
assert db_model.node_execution_id == domain_model.node_execution_id
assert db_model.node_id == domain_model.node_id
assert db_model.node_type == domain_model.node_type
assert db_model.title == domain_model.title
assert db_model.inputs_dict == domain_model.inputs
assert db_model.process_data_dict == domain_model.process_data
assert db_model.outputs_dict == domain_model.outputs
assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)
assert db_model.status == domain_model.status
assert db_model.error == domain_model.error
assert db_model.elapsed_time == domain_model.elapsed_time
assert db_model.created_at == domain_model.created_at
assert db_model.created_by_role == CreatorUserRole.ACCOUNT
assert db_model.created_by == account.id
assert db_model.finished_at == domain_model.finished_at
class TestToDomainModel:
def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
repo = _make_repo(db_session_with_containers, account, app_id)
inputs_dict = {"input_key": "input_value"}
process_data_dict = {"process_key": "process_value"}
outputs_dict = {"output_key": "output_value"}
metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
now = datetime.now()
db_model = WorkflowNodeExecutionModel()
db_model.id = "test-id"
db_model.tenant_id = account.current_tenant_id
db_model.app_id = app_id
db_model.workflow_id = "test-workflow-id"
db_model.triggered_from = "workflow-run"
db_model.workflow_run_id = "test-workflow-run-id"
db_model.index = 1
db_model.predecessor_node_id = "test-predecessor-id"
db_model.node_execution_id = "test-node-execution-id"
db_model.node_id = "test-node-id"
db_model.node_type = BuiltinNodeTypes.START
db_model.title = "Test Node"
db_model.inputs = json.dumps(inputs_dict)
db_model.process_data = json.dumps(process_data_dict)
db_model.outputs = json.dumps(outputs_dict)
db_model.status = WorkflowNodeExecutionStatus.RUNNING
db_model.error = None
db_model.elapsed_time = 1.5
db_model.execution_metadata = json.dumps(metadata_dict)
db_model.created_at = now
db_model.created_by_role = "account"
db_model.created_by = account.id
db_model.finished_at = None
domain_model = repo._to_domain_model(db_model)
assert isinstance(domain_model, WorkflowNodeExecution)
assert domain_model.id == "test-id"
assert domain_model.workflow_id == "test-workflow-id"
assert domain_model.workflow_execution_id == "test-workflow-run-id"
assert domain_model.index == 1
assert domain_model.predecessor_node_id == "test-predecessor-id"
assert domain_model.node_execution_id == "test-node-execution-id"
assert domain_model.node_id == "test-node-id"
assert domain_model.node_type == BuiltinNodeTypes.START
assert domain_model.title == "Test Node"
assert domain_model.inputs == inputs_dict
assert domain_model.process_data == process_data_dict
assert domain_model.outputs == outputs_dict
assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING
assert domain_model.error is None
assert domain_model.elapsed_time == 1.5
assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()}
assert domain_model.created_at == now
assert domain_model.finished_at is None
def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None:
account = _create_account_with_tenant(db_session_with_containers)
repo = _make_repo(db_session_with_containers, account, str(uuid4()))
process_data = {"normal": "data"}
db_model = WorkflowNodeExecutionModel()
db_model.id = str(uuid4())
db_model.tenant_id = account.current_tenant_id
db_model.app_id = str(uuid4())
db_model.workflow_id = str(uuid4())
db_model.triggered_from = "workflow-run"
db_model.workflow_run_id = None
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_execution_id = str(uuid4())
db_model.node_id = "test-node-id"
db_model.node_type = "llm"
db_model.title = "Test Node"
db_model.inputs = None
db_model.process_data = json.dumps(process_data)
db_model.outputs = None
db_model.status = "succeeded"
db_model.error = None
db_model.elapsed_time = 1.5
db_model.execution_metadata = "{}"
db_model.created_at = datetime.now()
db_model.created_by_role = "account"
db_model.created_by = account.id
db_model.finished_at = None
domain_model = repo._to_domain_model(db_model)
assert domain_model.process_data == process_data
assert domain_model.process_data_truncated is False
assert domain_model.get_truncated_process_data() is None

View File

@ -0,0 +1,255 @@
"""
Integration tests for RagPipelineService methods that interact with the database.
Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing
db.session.scalar/commit/delete mocker patches with real PostgreSQL operations.
Covers:
- get_pipeline: Dataset and Pipeline lookups
- update_customized_pipeline_template: find + unique-name check + commit
- delete_customized_pipeline_template: find + delete + commit
"""
from collections.abc import Generator
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session, sessionmaker
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.enums import DataSourceType
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
class TestRagPipelineServiceGetPipeline:
"""Integration tests for RagPipelineService.get_pipeline."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
yield
db_session_with_containers.rollback()
def _make_service(self, flask_app_with_containers) -> RagPipelineService:
with (
patch(
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
return_value=None,
),
patch(
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=None,
),
):
session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine)
return RagPipelineService(session_maker=session_factory)
def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline:
pipeline = Pipeline(
tenant_id=tenant_id,
name=f"Pipeline {uuid4()}",
description="",
created_by=created_by,
)
db_session.add(pipeline)
db_session.flush()
return pipeline
def _create_dataset(
self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None
) -> Dataset:
dataset = Dataset(
tenant_id=tenant_id,
name=f"Dataset {uuid4()}",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
pipeline_id=pipeline_id,
)
db_session.add(dataset)
db_session.flush()
return dataset
def test_get_pipeline_raises_when_dataset_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""get_pipeline raises ValueError when dataset does not exist."""
service = self._make_service(flask_app_with_containers)
with pytest.raises(ValueError, match="Dataset not found"):
service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4()))
def test_get_pipeline_raises_when_pipeline_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""get_pipeline raises ValueError when dataset exists but has no linked pipeline."""
tenant_id = str(uuid4())
created_by = str(uuid4())
dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None)
db_session_with_containers.flush()
service = self._make_service(flask_app_with_containers)
with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"):
service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)
def test_get_pipeline_returns_pipeline_when_found(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""get_pipeline returns the Pipeline when both Dataset and Pipeline exist."""
tenant_id = str(uuid4())
created_by = str(uuid4())
pipeline = self._create_pipeline(db_session_with_containers, tenant_id, created_by)
dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id)
db_session_with_containers.flush()
service = self._make_service(flask_app_with_containers)
result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)
assert result.id == pipeline.id
class TestUpdateCustomizedPipelineTemplate:
"""Integration tests for RagPipelineService.update_customized_pipeline_template."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
yield
db_session_with_containers.rollback()
def _create_template(
self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template"
) -> PipelineCustomizedTemplate:
template = PipelineCustomizedTemplate(
tenant_id=tenant_id,
name=name,
description="Original description",
chunk_structure="fixed_size",
icon={"type": "emoji", "value": "📄"},
position=1,
yaml_content="{}",
install_count=0,
language="en-US",
created_by=created_by,
)
db_session.add(template)
db_session.flush()
return template
def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
"""update_customized_pipeline_template updates name and description."""
tenant_id = str(uuid4())
created_by = str(uuid4())
template = self._create_template(db_session_with_containers, tenant_id, created_by)
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="Updated Name",
description="Updated description",
icon_info=IconInfo(icon="🔥"),
)
result = RagPipelineService.update_customized_pipeline_template(template.id, info)
assert result.name == "Updated Name"
assert result.description == "Updated description"
def test_update_template_raises_when_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""update_customized_pipeline_template raises ValueError when template doesn't exist."""
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="New Name",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.update_customized_pipeline_template(str(uuid4()), info)
def test_update_template_raises_on_duplicate_name(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""update_customized_pipeline_template raises ValueError when new name already exists."""
tenant_id = str(uuid4())
created_by = str(uuid4())
template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original")
self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate")
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="Duplicate",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Template name is already exists"):
RagPipelineService.update_customized_pipeline_template(template1.id, info)
class TestDeleteCustomizedPipelineTemplate:
"""Integration tests for RagPipelineService.delete_customized_pipeline_template."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
yield
db_session_with_containers.rollback()
def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate:
template = PipelineCustomizedTemplate(
tenant_id=tenant_id,
name=f"Template {uuid4()}",
description="Description",
chunk_structure="fixed_size",
icon={"type": "emoji", "value": "📄"},
position=1,
yaml_content="{}",
install_count=0,
language="en-US",
created_by=created_by,
)
db_session.add(template)
db_session.flush()
return template
def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
"""delete_customized_pipeline_template removes the template from the DB."""
tenant_id = str(uuid4())
created_by = str(uuid4())
template = self._create_template(db_session_with_containers, tenant_id, created_by)
template_id = template.id
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
RagPipelineService.delete_customized_pipeline_template(template_id)
# Verify the record is deleted within the same context
from sqlalchemy import select
from extensions.ext_database import db as ext_db
remaining = ext_db.session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
)
assert remaining is None
def test_delete_template_raises_when_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
) -> None:
"""delete_customized_pipeline_template raises ValueError when template doesn't exist."""
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.delete_customized_pipeline_template(str(uuid4()))

View File

@ -26,7 +26,7 @@ from datetime import timedelta
import pytest
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus
- from sqlalchemy import delete, select
+ from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from extensions.ext_storage import storage
@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration:
# Verify only 3 were deleted
remaining_count = (
-     self.session.query(WorkflowPauseModel)
-     .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
-     .count()
+     self.session.scalar(
+         select(func.count(WorkflowPauseModel.id)).where(
+             WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])
+         )
+     )
+     or 0
)
assert remaining_count == 2
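The count above follows the same pattern: legacy `Query.count()` becomes an explicit `select(func.count(...))`, with `or 0` guarding the `Optional` return type of `Session.scalar()`. A minimal sketch with a generic mapped `Item` class:

from sqlalchemy import func, select
from sqlalchemy.orm import Session

# `Item` is assumed to be any mapped class with an `id` column.
def count_matching(session: Session, Item, ids: list[str]) -> int:
    # 1.x: session.query(Item).filter(Item.id.in_(ids)).count()
    return session.scalar(select(func.count(Item.id)).where(Item.id.in_(ids))) or 0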

View File

@ -258,10 +258,10 @@ class TestParentChildIndexProcessor:
session.commit.assert_called_once()
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
- segment_query = Mock()
- segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
+ scalars_result = Mock()
+ scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
session = Mock()
- session.query.return_value = segment_query
+ session.scalars.return_value = scalars_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False

View File

@ -220,10 +220,10 @@ class TestQAIndexProcessor:
self, processor: QAIndexProcessor, dataset: Mock
) -> None:
mock_segment = SimpleNamespace(id="seg-1")
- mock_query = Mock()
- mock_query.filter.return_value.all.return_value = [mock_segment]
+ scalars_result = Mock()
+ scalars_result.all.return_value = [mock_segment]
mock_session = Mock()
- mock_session.query.return_value = mock_query
+ mock_session.scalars.return_value = scalars_result
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
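Both index-processor tests swap their fakes in the same way: instead of stubbing the `session.query(...).filter(...).all()` chain, they stub the object returned by `session.scalars(...)`. A standalone sketch of the mocking pattern, assuming code under test that calls `session.scalars(stmt).all()` inside a session context manager:

from types import SimpleNamespace
from unittest.mock import MagicMock, Mock

row = SimpleNamespace(id="seg-1")

# Fake the 2.0-style read path: session.scalars(stmt) -> result, result.all() -> rows.
scalars_result = Mock()
scalars_result.all.return_value = [row]

session = Mock()
session.scalars.return_value = scalars_result

# Context-manager plumbing so `with session_factory() as s:` yields the fake session.
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False

assert session.scalars(object()).all() == [row]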

View File

@ -8,7 +8,6 @@ import pytest
from flask import Flask, current_app
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelFeature
- from sqlalchemy import column
from core.app.app_config.entities import (
DatasetEntity,
@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers:
def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None:
session = Mock()
- subquery_query = Mock()
- subquery_query.where.return_value = subquery_query
- subquery_query.group_by.return_value = subquery_query
- subquery_query.having.return_value = subquery_query
- subquery_query.subquery.return_value = SimpleNamespace(
-     c=SimpleNamespace(
-         dataset_id=column("dataset_id"), available_document_count=column("available_document_count")
-     )
- )
- dataset_query = Mock()
- dataset_query.outerjoin.return_value = dataset_query
- dataset_query.where.return_value = dataset_query
- dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
- session.query.side_effect = [subquery_query, dataset_query]
+ scalars_result = Mock()
+ scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
+ session.scalars.return_value = scalars_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage:
_scalars(segments),
_scalars(bindings),
]
- query = Mock()
- query.where.return_value = query
- session.query.return_value = query
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage:
):
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
- query.update.assert_called_once()
+ session.execute.assert_called_once()
mock_trace.assert_called_once()
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:

View File

@ -260,6 +260,28 @@ def test_agent_invoke_engine_meta_error():
assert error_meta.error == "meta failure"
def test_convert_tool_response_excludes_variable_messages():
"""Regression test for issue #34723.
WorkflowTool._invoke yields VARIABLE, TEXT, and suppressed-JSON messages.
_convert_tool_response_to_str must skip VARIABLE messages so that the
returned string contains only the TEXT representation and not a
duplicated, garbled Pydantic repr of the same data.
"""
tool = _build_tool()
outputs = {"reports": "hello"}
messages = [
tool.create_variable_message(variable_name="reports", variable_value="hello"),
tool.create_text_message('{"reports": "hello"}'),
tool.create_json_message(outputs, suppress_output=True),
]
result = ToolEngine._convert_tool_response_to_str(messages)
assert result == '{"reports": "hello"}'
assert "variable_name" not in result
def test_agent_invoke_tool_invoke_error():
tool = _build_tool(with_llm_parameter=True)
callback = Mock()

View File

@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql():
for scheme in ("postgresql", "mysql"):
session = Mock()
session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
- session.query.return_value.where.return_value.all.return_value = provider_records
+ session.scalars.return_value = iter(provider_records)
with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)):
with patch("core.tools.tool_manager.db") as mock_db:

View File

@ -291,24 +291,6 @@ class TestAppModelConfig:
# Assert
assert result == questions
def test_app_model_config_annotation_reply_dict_disabled(self):
"""Test annotation_reply_dict when annotation is disabled."""
# Arrange
config = AppModelConfig(
app_id=str(uuid4()),
provider="openai",
model_id="gpt-4",
created_by=str(uuid4()),
)
# Mock database scalar to return None (no annotation setting found)
with patch("models.model.db.session.scalar", return_value=None):
# Act
result = config.annotation_reply_dict
# Assert
assert result == {"enabled": False}
class TestConversationModel:
"""Test suite for Conversation model integrity."""
@ -948,17 +930,6 @@ class TestSiteModel:
with pytest.raises(ValueError, match="Custom disclaimer cannot exceed 512 characters"):
site.custom_disclaimer = long_disclaimer
def test_site_generate_code(self):
"""Test Site.generate_code static method."""
# Mock database scalar to return 0 (no existing codes)
with patch("models.model.db.session.scalar", return_value=0):
# Act
code = Site.generate_code(8)
# Assert
assert isinstance(code, str)
assert len(code) == 8
class TestModelIntegration:
"""Test suite for model integration scenarios."""
@ -1146,314 +1117,3 @@ class TestModelIntegration:
# Assert
assert site.app_id == app.id
assert app.enable_site is True
class TestConversationStatusCount:
"""Test suite for Conversation.status_count property N+1 query fix."""
def test_status_count_no_messages(self):
"""Test status_count returns None when conversation has no messages."""
# Arrange
conversation = Conversation(
app_id=str(uuid4()),
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = str(uuid4())
# Mock the database query to return no messages
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
result = conversation.status_count
# Assert
assert result is None
def test_status_count_messages_without_workflow_runs(self):
"""Test status_count when messages have no workflow_run_id."""
# Arrange
app_id = str(uuid4())
conversation_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
# Mock the database query to return no messages with workflow_run_id
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
result = conversation.status_count
# Assert
assert result is None
def test_status_count_batch_loading_implementation(self):
"""Test that status_count uses batch loading instead of N+1 queries."""
# Arrange
from graphon.enums import WorkflowExecutionStatus
app_id = str(uuid4())
conversation_id = str(uuid4())
# Create workflow run IDs
workflow_run_id_1 = str(uuid4())
workflow_run_id_2 = str(uuid4())
workflow_run_id_3 = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
# Mock messages with workflow_run_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_1,
),
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_2,
),
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_3,
),
]
# Mock workflow runs with different statuses
mock_workflow_runs = [
MagicMock(
id=workflow_run_id_1,
status=WorkflowExecutionStatus.SUCCEEDED.value,
app_id=app_id,
),
MagicMock(
id=workflow_run_id_2,
status=WorkflowExecutionStatus.FAILED.value,
app_id=app_id,
),
MagicMock(
id=workflow_run_id_3,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
app_id=app_id,
),
]
# Track database calls
calls_made = []
def mock_scalars(query):
calls_made.append(str(query))
mock_result = MagicMock()
# Return messages for the first query (messages with workflow_run_id)
if "messages" in str(query) and "conversation_id" in str(query):
mock_result.all.return_value = mock_messages
# Return workflow runs for the batch query
elif "workflow_runs" in str(query):
mock_result.all.return_value = mock_workflow_runs
else:
mock_result.all.return_value = []
return mock_result
# Act & Assert
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Verify only 2 database queries were made (not N+1)
assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"
# Verify the first query gets messages
assert "messages" in calls_made[0]
assert "conversation_id" in calls_made[0]
# Verify the second query batch loads workflow runs with proper filtering
assert "workflow_runs" in calls_made[1]
assert "app_id" in calls_made[1] # Security filter applied
assert "IN" in calls_made[1] # Batch loading with IN clause
# Verify correct status counts
assert result["success"] == 1 # One SUCCEEDED
assert result["failed"] == 1 # One FAILED
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
assert result["paused"] == 0
def test_status_count_app_id_filtering(self):
"""Test that status_count filters workflow runs by app_id for security."""
# Arrange
app_id = str(uuid4())
other_app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_run_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
# Mock message with workflow_run_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id,
),
]
calls_made = []
def mock_scalars(query):
calls_made.append(str(query))
mock_result = MagicMock()
if "messages" in str(query):
mock_result.all.return_value = mock_messages
elif "workflow_runs" in str(query):
# Return empty list because no workflow run matches the correct app_id
mock_result.all.return_value = [] # Workflow run filtered out by app_id
else:
mock_result.all.return_value = []
return mock_result
# Act
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Assert - query should include app_id filter
workflow_query = calls_made[1]
assert "app_id" in workflow_query
# Since workflow run has wrong app_id, it shouldn't be included in counts
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_handles_invalid_workflow_status(self):
"""Test that status_count gracefully handles invalid workflow status values."""
# Arrange
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_run_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id,
),
]
# Mock workflow run with invalid status
mock_workflow_runs = [
MagicMock(
id=workflow_run_id,
status="invalid_status", # Invalid status that should raise ValueError
app_id=app_id,
),
]
with patch("models.model.db.session.scalars") as mock_scalars:
# Mock the messages query
def mock_scalars_side_effect(query):
mock_result = MagicMock()
if "messages" in str(query):
mock_result.all.return_value = mock_messages
elif "workflow_runs" in str(query):
mock_result.all.return_value = mock_workflow_runs
else:
mock_result.all.return_value = []
return mock_result
mock_scalars.side_effect = mock_scalars_side_effect
# Act - should not raise exception
result = conversation.status_count
# Assert - should handle invalid status gracefully
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_paused(self):
"""Test status_count includes paused workflow runs."""
# Arrange
from graphon.enums import WorkflowExecutionStatus
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_run_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id,
),
]
mock_workflow_runs = [
MagicMock(
id=workflow_run_id,
status=WorkflowExecutionStatus.PAUSED.value,
app_id=app_id,
),
]
with patch("models.model.db.session.scalars") as mock_scalars:
def mock_scalars_side_effect(query):
mock_result = MagicMock()
if "messages" in str(query):
mock_result.all.return_value = mock_messages
elif "workflow_runs" in str(query):
mock_result.all.return_value = mock_workflow_runs
else:
mock_result.all.return_value = []
return mock_result
mock_scalars.side_effect = mock_scalars_side_effect
# Act
result = conversation.status_count
# Assert
assert result["paused"] == 1
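# --- Editor's note: the class above asserts that Conversation.status_count batch-loads
# workflow runs with a single IN query (filtered by app_id) instead of one query per
# message. A minimal, self-contained sketch of that batch-loading pattern follows;
# `fetch_runs_by_ids` is an illustrative placeholder, not the model's real API.
from collections import Counter
def count_run_statuses(message_run_ids, fetch_runs_by_ids):
    """Count workflow-run statuses with one batched lookup instead of N+1 lookups."""
    run_ids = [run_id for run_id in message_run_ids if run_id]
    if not run_ids:
        return None
    # One call == one query: the real property issues a single SELECT with an IN clause
    # (plus an app_id filter) rather than querying workflow_runs once per message.
    statuses = fetch_runs_by_ids(run_ids)
    return Counter(statuses)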

View File

@@ -101,118 +101,6 @@ def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -
return mapping
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolve_owner_tenant_for_single_file_mapping(
monkeypatch: pytest.MonkeyPatch,
owner_cls: type[Conversation] | type[Message],
):
model_module = importlib.import_module("models.model")
build_calls: list[tuple[dict[str, object], str]] = []
monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)
owner = owner_cls(app_id="app-1")
owner.inputs = {"file": _build_local_file_mapping("upload-1")}
restored_inputs = owner.inputs
assert restored_inputs["file"] == {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"}
assert build_calls == [
(
{
**_build_local_file_mapping("upload-1"),
"upload_file_id": "upload-1",
},
"tenant-from-app",
)
]
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolve_owner_tenant_for_file_list_mapping(
monkeypatch: pytest.MonkeyPatch,
owner_cls: type[Conversation] | type[Message],
):
model_module = importlib.import_module("models.model")
build_calls: list[tuple[dict[str, object], str]] = []
monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)
owner = owner_cls(app_id="app-1")
owner.inputs = {
"files": [
_build_local_file_mapping("upload-1"),
_build_local_file_mapping("upload-2"),
]
}
restored_inputs = owner.inputs
assert restored_inputs["files"] == [
{"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"},
{"tenant_id": "tenant-from-app", "upload_file_id": "upload-2"},
]
assert build_calls == [
(
{
**_build_local_file_mapping("upload-1"),
"upload_file_id": "upload-1",
},
"tenant-from-app",
),
(
{
**_build_local_file_mapping("upload-2"),
"upload_file_id": "upload-2",
},
"tenant-from-app",
),
]
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_prefer_serialized_tenant_id_when_present(
monkeypatch: pytest.MonkeyPatch,
owner_cls: type[Conversation] | type[Message],
):
model_module = importlib.import_module("models.model")
def fail_if_called(_):
raise AssertionError("App tenant lookup should not run when tenant_id exists in the file mapping")
monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called)
def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
_ = config, strict_type_validation, access_controller
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}
monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)
owner = owner_cls(app_id="app-1")
owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id="tenant-from-payload")}
restored_inputs = owner.inputs
assert restored_inputs["file"] == {
"tenant_id": "tenant-from-payload",
"upload_file_id": "upload-1",
}
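# --- Editor's note: an illustrative reduction of the tenant-resolution rule the tests
# above exercise. `lookup_app_tenant_id` stands in for the app lookup performed via
# db.session.scalar and is not the real helper name.
def resolve_file_tenant_id(mapping, lookup_app_tenant_id):
    """Prefer a tenant_id already serialized in the file mapping; else resolve it from the app."""
    serialized = mapping.get("tenant_id")
    if serialized:
        return serialized  # no app lookup needed when the mapping carries the tenant
    return lookup_app_tenant_id()  # otherwise resolve the tenant of the owning app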
@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_restore_external_remote_url_file_mappings(owner_cls: type[Conversation] | type[Message]) -> None:
owner = owner_cls(app_id="app-1")

View File

@@ -1,188 +0,0 @@
import types
import pytest
from models.engine import db
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel
@pytest.fixture
def fake_db_scalar(monkeypatch):
"""Provide a controllable fake for db.session.scalar (SQLAlchemy 2.0 style)."""
calls = []
def _install(side_effect):
def _fake_scalar(statement):
calls.append(statement)
return side_effect(statement)
# Patch the modern API used by the model implementation
monkeypatch.setattr(db.session, "scalar", _fake_scalar)
# Backward-compatibility: if the implementation still uses db.session.get,
# make it delegate to the same side_effect so tests remain valid on older code.
if hasattr(db.session, "get"):
def _fake_get(*_args, **_kwargs):
return side_effect(None)
monkeypatch.setattr(db.session, "get", _fake_get)
return calls
return _install
def make_account(id_: str = "acc-1"):
# Use a simple object to avoid constructing a full SQLAlchemy model instance
# Python 3.12 forbids reassigning __class__ for SimpleNamespace; not needed here.
obj = types.SimpleNamespace()
obj.id = id_
return obj
def make_end_user(id_: str = "user-1"):
# Lightweight stand-in object; no need to spoof class identity.
obj = types.SimpleNamespace()
obj.id = id_
return obj
def test_created_by_account_returns_account_when_role_account(fake_db_scalar):
account = make_account("acc-1")
# The implementation uses db.session.scalar(select(Account)...). We only need to
# return the expected object when called; the exact SQL is irrelevant for this unit test.
def side_effect(_statement):
return account
fake_db_scalar(side_effect)
log = WorkflowNodeExecutionModel(
tenant_id="t1",
app_id="a1",
workflow_id="w1",
triggered_from="workflow-run",
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by="acc-1",
)
assert log.created_by_account is account
def test_created_by_account_returns_none_when_role_not_account(fake_db_scalar):
# Even if an Account with matching id exists, property should return None when role is END_USER
account = make_account("acc-1")
def side_effect(_statement):
return account
fake_db_scalar(side_effect)
log = WorkflowNodeExecutionModel(
tenant_id="t1",
app_id="a1",
workflow_id="w1",
triggered_from="workflow-run",
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=CreatorUserRole.END_USER.value,
created_by="acc-1",
)
assert log.created_by_account is None
def test_created_by_end_user_returns_end_user_when_role_end_user(fake_db_scalar):
end_user = make_end_user("user-1")
def side_effect(_statement):
return end_user
fake_db_scalar(side_effect)
log = WorkflowNodeExecutionModel(
tenant_id="t1",
app_id="a1",
workflow_id="w1",
triggered_from="workflow-run",
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=CreatorUserRole.END_USER.value,
created_by="user-1",
)
assert log.created_by_end_user is end_user
def test_created_by_end_user_returns_none_when_role_not_end_user(fake_db_scalar):
end_user = make_end_user("user-1")
def side_effect(_statement):
return end_user
fake_db_scalar(side_effect)
log = WorkflowNodeExecutionModel(
tenant_id="t1",
app_id="a1",
workflow_id="w1",
triggered_from="workflow-run",
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by="user-1",
)
assert log.created_by_end_user is None
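# --- Editor's note: the four tests above pin the role gating of created_by_account /
# created_by_end_user. A minimal sketch of that gating; `loader` stands in for the
# db.session.scalar lookup and is not the model's real API.
def gate_creator_lookup(created_by_role, expected_role, creator_id, loader):
    """Return loader(creator_id) only when the record's role matches the expected role."""
    return loader(creator_id) if created_by_role == expected_role else None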

View File

@@ -1,3 +0,0 @@
"""
Unit tests for workflow_node_execution repositories.
"""

View File

@@ -1,340 +0,0 @@
"""
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
"""
import json
import uuid
from datetime import datetime
from decimal import Decimal
from unittest.mock import MagicMock, PropertyMock
import pytest
from graphon.entities import (
WorkflowNodeExecution,
)
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session, sessionmaker
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig
from models.account import Account, Tenant
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
def configure_mock_execution(mock_execution):
"""Configure a mock execution with proper JSON serializable values."""
# Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values
type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}')
type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}')
type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}')
type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}')
# Configure status and triggered_from to be valid enum values
mock_execution.status = "running"
mock_execution.triggered_from = "workflow-run"
return mock_execution
@pytest.fixture
def session():
"""Create a mock SQLAlchemy session."""
session = MagicMock(spec=Session)
# Configure the session to be used as a context manager
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=None)
# Configure the session factory to return the session
session_factory = MagicMock(spec=sessionmaker)
session_factory.return_value = session
return session, session_factory
@pytest.fixture
def mock_user():
"""Create a user instance for testing."""
user = Account(name="test", email="test@example.com")
user.id = "test-user-id"
tenant = Tenant(name="Test Workspace")
tenant.id = "test-tenant"
user._current_tenant = MagicMock()
user._current_tenant.id = "test-tenant"
return user
@pytest.fixture
def repository(session, mock_user):
"""Create a repository instance with test data."""
_, session_factory = session
app_id = "test-app"
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=mock_user,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def test_save(repository, session):
"""Test save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution.id = "test-id"
execution.node_execution_id = "test-node-execution-id"
execution.tenant_id = None
execution.app_id = None
execution.inputs = None
execution.process_data = None
execution.outputs = None
execution.metadata = None
execution.workflow_id = str(uuid.uuid4())
# Mock the to_db_model method to return the execution itself
# This simulates the behavior of setting tenant_id and app_id
db_model = MagicMock(spec=WorkflowNodeExecutionModel)
db_model.id = "test-id"
db_model.node_execution_id = "test-node-execution-id"
repository._to_db_model = MagicMock(return_value=db_model)
# Mock session.get to return None (no existing record)
session_obj.get.return_value = None
# Call save method
repository.save(execution)
# Assert to_db_model was called with the execution
repository._to_db_model.assert_called_once_with(execution)
# Assert session.get was called to check for existing record
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)
# Assert session.add was called for new record
session_obj.add.assert_called_once_with(db_model)
# Assert session.commit was called
session_obj.commit.assert_called_once()
def test_save_with_existing_tenant_id(repository, session):
"""Test save method with existing tenant_id."""
session_obj, _ = session
# Create a mock execution with existing tenant_id
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = "existing-id"
execution.node_execution_id = "existing-node-execution-id"
execution.tenant_id = "existing-tenant"
execution.app_id = None
execution.inputs = None
execution.process_data = None
execution.outputs = None
execution.metadata = None
# Create a modified execution that will be returned by _to_db_model
modified_execution = MagicMock(spec=WorkflowNodeExecutionModel)
modified_execution.id = "existing-id"
modified_execution.node_execution_id = "existing-node-execution-id"
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
modified_execution.app_id = repository._app_id # App ID should be set
# Create a dictionary to simulate __dict__ for updating attributes
modified_execution.__dict__ = {
"id": "existing-id",
"node_execution_id": "existing-node-execution-id",
"tenant_id": "existing-tenant",
"app_id": repository._app_id,
}
# Mock the to_db_model method to return the modified execution
repository._to_db_model = MagicMock(return_value=modified_execution)
# Mock session.get to return an existing record
existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
session_obj.get.return_value = existing_model
# Call save method
repository.save(execution)
# Assert to_db_model was called with the execution
repository._to_db_model.assert_called_once_with(execution)
# Assert session.get was called to check for existing record
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)
# Assert session.add was NOT called since we're updating existing
session_obj.add.assert_not_called()
# Assert session.commit was called
session_obj.commit.assert_called_once()
def test_get_by_workflow_execution(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_execution method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")
mock_WorkflowNodeExecutionModel = mocker.patch(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
)
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_asc.return_value = mock_stmt
mock_desc.return_value = mock_stmt
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalars.return_value.all.return_value = [mock_execution]
# Create a mock domain model to be returned by _to_domain_model
mock_domain_model = mocker.MagicMock()
# Mock the _to_domain_model method to return our mock domain model
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
# Call method
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repository.get_by_workflow_execution(
workflow_execution_id="test-workflow-run-id",
order_config=order_config,
)
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
# Assert _to_domain_model was called with the mock execution
repository._to_domain_model.assert_called_once_with(mock_execution)
# Assert the result contains our mock domain model
assert len(result) == 1
assert result[0] is mock_domain_model
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model
domain_model = WorkflowNodeExecution(
id="test-id",
workflow_id="test-workflow-id",
node_execution_id="test-node-execution-id",
workflow_execution_id="test-workflow-run-id",
index=1,
predecessor_node_id="test-predecessor-id",
node_id="test-node-id",
node_type=BuiltinNodeTypes.START,
title="Test Node",
inputs={"input_key": "input_value"},
process_data={"process_key": "process_value"},
outputs={"output_key": "output_value"},
status=WorkflowNodeExecutionStatus.RUNNING,
error=None,
elapsed_time=1.5,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
},
created_at=datetime.now(),
finished_at=None,
)
# Convert to DB model
db_model = repository._to_db_model(domain_model)
# Assert DB model has correct values
assert isinstance(db_model, WorkflowNodeExecutionModel)
assert db_model.id == domain_model.id
assert db_model.tenant_id == repository._tenant_id
assert db_model.app_id == repository._app_id
assert db_model.workflow_id == domain_model.workflow_id
assert db_model.triggered_from == repository._triggered_from
assert db_model.workflow_run_id == domain_model.workflow_execution_id
assert db_model.index == domain_model.index
assert db_model.predecessor_node_id == domain_model.predecessor_node_id
assert db_model.node_execution_id == domain_model.node_execution_id
assert db_model.node_id == domain_model.node_id
assert db_model.node_type == domain_model.node_type
assert db_model.title == domain_model.title
assert db_model.inputs_dict == domain_model.inputs
assert db_model.process_data_dict == domain_model.process_data
assert db_model.outputs_dict == domain_model.outputs
assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)
assert db_model.status == domain_model.status
assert db_model.error == domain_model.error
assert db_model.elapsed_time == domain_model.elapsed_time
assert db_model.created_at == domain_model.created_at
assert db_model.created_by_role == repository._creator_user_role
assert db_model.created_by == repository._creator_user_id
assert db_model.finished_at == domain_model.finished_at
def test_to_domain_model(repository):
"""Test _to_domain_model method."""
# Create input dictionaries
inputs_dict = {"input_key": "input_value"}
process_data_dict = {"process_key": "process_value"}
outputs_dict = {"output_key": "output_value"}
metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
# Create a DB model using our custom subclass
db_model = WorkflowNodeExecutionModel()
db_model.id = "test-id"
db_model.tenant_id = "test-tenant-id"
db_model.app_id = "test-app-id"
db_model.workflow_id = "test-workflow-id"
db_model.triggered_from = "workflow-run"
db_model.workflow_run_id = "test-workflow-run-id"
db_model.index = 1
db_model.predecessor_node_id = "test-predecessor-id"
db_model.node_execution_id = "test-node-execution-id"
db_model.node_id = "test-node-id"
db_model.node_type = BuiltinNodeTypes.START
db_model.title = "Test Node"
db_model.inputs = json.dumps(inputs_dict)
db_model.process_data = json.dumps(process_data_dict)
db_model.outputs = json.dumps(outputs_dict)
db_model.status = WorkflowNodeExecutionStatus.RUNNING
db_model.error = None
db_model.elapsed_time = 1.5
db_model.execution_metadata = json.dumps(metadata_dict)
db_model.created_at = datetime.now()
db_model.created_by_role = "account"
db_model.created_by = "test-user-id"
db_model.finished_at = None
# Convert to domain model
domain_model = repository._to_domain_model(db_model)
# Assert domain model has correct values
assert isinstance(domain_model, WorkflowNodeExecution)
assert domain_model.id == db_model.id
assert domain_model.workflow_id == db_model.workflow_id
assert domain_model.workflow_execution_id == db_model.workflow_run_id
assert domain_model.index == db_model.index
assert domain_model.predecessor_node_id == db_model.predecessor_node_id
assert domain_model.node_execution_id == db_model.node_execution_id
assert domain_model.node_id == db_model.node_id
assert domain_model.node_type == db_model.node_type
assert domain_model.title == db_model.title
assert domain_model.inputs == inputs_dict
assert domain_model.process_data == process_data_dict
assert domain_model.outputs == outputs_dict
assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status)
assert domain_model.error == db_model.error
assert domain_model.elapsed_time == db_model.elapsed_time
assert domain_model.metadata == metadata_dict
assert domain_model.created_at == db_model.created_at
assert domain_model.finished_at == db_model.finished_at

View File

@@ -1,106 +0,0 @@
"""
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
"""
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, Mock
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
)
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
"""Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""
def create_mock_account(self) -> Account:
"""Create a mock Account for testing."""
account = Mock(spec=Account)
account.id = "test-user-id"
account.tenant_id = "test-tenant-id"
return account
def create_mock_session_factory(self) -> sessionmaker:
"""Create a mock session factory for testing."""
mock_session = MagicMock()
mock_session_factory = MagicMock(spec=sessionmaker)
mock_session_factory.return_value.__enter__.return_value = mock_session
mock_session_factory.return_value.__exit__.return_value = None
return mock_session_factory
def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository:
"""Create a repository instance for testing."""
mock_account = self.create_mock_account()
mock_session_factory = self.create_mock_session_factory()
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if mock_file_service:
repository._file_service = mock_file_service
return repository
def create_workflow_node_execution(
self,
process_data: dict[str, Any] | None = None,
execution_id: str = "test-execution-id",
) -> WorkflowNodeExecution:
"""Create a WorkflowNodeExecution instance for testing."""
return WorkflowNodeExecution(
id=execution_id,
workflow_id="test-workflow-id",
index=1,
node_id="test-node-id",
node_type=BuiltinNodeTypes.LLM,
title="Test Node",
process_data=process_data,
created_at=datetime.now(),
)
def test_to_domain_model_without_offload_data(self):
"""Test _to_domain_model without offload data."""
repository = self.create_repository()
# Create mock database model without offload data
db_model = Mock(spec=WorkflowNodeExecutionModel)
db_model.id = "test-execution-id"
db_model.node_execution_id = "test-node-execution-id"
db_model.workflow_id = "test-workflow-id"
db_model.workflow_run_id = None
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "test-node-id"
db_model.node_type = "llm"
db_model.title = "Test Node"
db_model.status = "succeeded"
db_model.error = None
db_model.elapsed_time = 1.5
db_model.created_at = datetime.now()
db_model.finished_at = None
process_data = {"normal": "data"}
db_model.process_data_dict = process_data
db_model.inputs_dict = None
db_model.outputs_dict = None
db_model.execution_metadata_dict = {}
db_model.offload_data = None
domain_model = repository._to_domain_model(db_model)
# Domain model should have the data from database
assert domain_model.process_data == process_data
# Should not be truncated
assert domain_model.process_data_truncated is False
assert domain_model.get_truncated_process_data() is None

File diff suppressed because it is too large

View File

@@ -1,825 +0,0 @@
"""
Comprehensive unit tests for DatasetPermissionService and DatasetService permission methods.
This module contains extensive unit tests for dataset permission management,
including partial member list operations, permission validation, and permission
enum handling.
The DatasetPermissionService provides methods for:
- Retrieving partial member permissions (get_dataset_partial_member_list)
- Updating partial member lists (update_partial_member_list)
- Validating permissions before operations (check_permission)
- Clearing partial member lists (clear_partial_member_list)
The DatasetService provides permission checking methods:
- check_dataset_permission - validates user access to dataset
- check_dataset_operator_permission - validates operator permissions
These operations are critical for dataset access control and security, ensuring
that users can only access datasets they have permission to view or modify.
This test suite ensures:
- Correct retrieval of partial member lists
- Proper update of partial member permissions
- Accurate permission validation logic
- Proper handling of permission enums (only_me, all_team_members, partial_members)
- Security boundaries are maintained
- Error conditions are handled correctly
================================================================================
ARCHITECTURE OVERVIEW
================================================================================
The Dataset permission system is a multi-layered access control mechanism
that provides fine-grained control over who can access and modify datasets.
1. Permission Levels:
- only_me: Only the dataset creator can access
- all_team_members: All members of the tenant can access
- partial_members: Only specific users listed in DatasetPermission can access
2. Permission Storage:
- Dataset.permission: Stores the permission level enum
- DatasetPermission: Stores individual user permissions for partial_members
- Each DatasetPermission record links a dataset to a user account
3. Permission Validation:
- Tenant-level checks: Users must be in the same tenant
- Role-based checks: OWNER role bypasses some restrictions
- Explicit permission checks: For partial_members, explicit DatasetPermission
records are required
4. Permission Operations:
- Partial member list management: Add/remove users from partial access
- Permission validation: Check before allowing operations
- Permission clearing: Remove all partial members when changing permission level
================================================================================
TESTING STRATEGY
================================================================================
This test suite follows a comprehensive testing strategy that covers:
1. Partial Member List Operations:
- Retrieving member lists
- Adding new members
- Updating existing members
- Removing members
- Empty list handling
2. Permission Validation:
- Dataset editor permissions
- Dataset operator restrictions
- Permission enum validation
- Partial member list validation
- Tenant isolation
3. Permission Enum Handling:
- only_me permission behavior
- all_team_members permission behavior
- partial_members permission behavior
- Permission transitions
- Edge cases for each enum value
4. Security and Access Control:
- Tenant boundary enforcement
- Role-based access control
- Creator privilege validation
- Explicit permission requirement
5. Error Handling:
- Invalid permission changes
- Missing required data
- Database transaction failures
- Permission denial scenarios
================================================================================
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models import Account, TenantAccountRole
from models.dataset import (
Dataset,
DatasetPermission,
DatasetPermissionEnum,
)
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
# ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of models or services changes, we only
# need to update the factory methods rather than every individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
# reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
# method calls instead of complex object construction logic.
#
# ============================================================================
class DatasetPermissionTestDataFactory:
"""
Factory class for creating test data and mock objects for dataset permission tests.
This factory provides static methods to create mock objects for:
- Dataset instances with various permission configurations
- User/Account instances with different roles and permissions
- DatasetPermission instances
- Permission enum values
- Database query results
The factory methods help maintain consistency across tests and reduce
code duplication when setting up test scenarios.
"""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
created_by: str = "user-123",
name: str = "Test Dataset",
**kwargs,
) -> Mock:
"""
Create a mock Dataset with specified attributes.
Args:
dataset_id: Unique identifier for the dataset
tenant_id: Tenant identifier
permission: Permission level enum
created_by: ID of user who created the dataset
name: Dataset name
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a Dataset instance
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.permission = permission
dataset.created_by = created_by
dataset.name = name
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-123",
tenant_id: str = "tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
is_dataset_editor: bool = True,
is_dataset_operator: bool = False,
**kwargs,
) -> Mock:
"""
Create a mock user (Account) with specified attributes.
Args:
user_id: Unique identifier for the user
tenant_id: Tenant identifier
role: User role (OWNER, ADMIN, NORMAL, DATASET_OPERATOR, etc.)
is_dataset_editor: Whether user has dataset editor permissions
is_dataset_operator: Whether user is a dataset operator
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as an Account instance
"""
user = create_autospec(Account, instance=True)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
user.is_dataset_editor = is_dataset_editor
user.is_dataset_operator = is_dataset_operator
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_dataset_permission_mock(
permission_id: str = "permission-123",
dataset_id: str = "dataset-123",
account_id: str = "user-456",
tenant_id: str = "tenant-123",
has_permission: bool = True,
**kwargs,
) -> Mock:
"""
Create a mock DatasetPermission instance.
Args:
permission_id: Unique identifier for the permission
dataset_id: Dataset ID
account_id: User account ID
tenant_id: Tenant identifier
has_permission: Whether permission is granted
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a DatasetPermission instance
"""
permission = Mock(spec=DatasetPermission)
permission.id = permission_id
permission.dataset_id = dataset_id
permission.account_id = account_id
permission.tenant_id = tenant_id
permission.has_permission = has_permission
for key, value in kwargs.items():
setattr(permission, key, value)
return permission
@staticmethod
def create_user_list_mock(user_ids: list[str]) -> list[dict[str, str]]:
"""
Create a list of user dictionaries for partial member list operations.
Args:
user_ids: List of user IDs to include
Returns:
List of user dictionaries with "user_id" keys
"""
return [{"user_id": user_id} for user_id in user_ids]
# ============================================================================
# Tests for check_permission
# ============================================================================
class TestDatasetPermissionServiceCheckPermission:
"""
Comprehensive unit tests for DatasetPermissionService.check_permission method.
This test class covers the permission validation logic that ensures
users have the appropriate permissions to modify dataset permissions.
The check_permission method:
1. Validates user is a dataset editor
2. Checks if dataset operator is trying to change permissions
3. Validates partial member list when setting to partial_members
4. Ensures dataset operators cannot change permission levels
5. Ensures dataset operators cannot modify partial member lists
Test scenarios include:
- Valid permission changes by dataset editors
- Dataset operator restrictions
- Partial member list validation
- Missing dataset editor permissions
- Invalid permission changes
"""
@pytest.fixture
def mock_get_partial_member_list(self):
"""
Mock get_dataset_partial_member_list method.
Provides a mocked version of the get_dataset_partial_member_list
method for testing permission validation logic.
"""
with patch.object(DatasetPermissionService, "get_dataset_partial_member_list") as mock_get_list:
yield mock_get_list
def test_check_permission_dataset_editor_success(self, mock_get_partial_member_list):
"""
Test successful permission check for dataset editor.
Verifies that when a dataset editor (not operator) tries to
change permissions, the check passes.
This test ensures:
- Dataset editors can change permissions
- No errors are raised for valid changes
- Partial member list validation is skipped for non-operators
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=False)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
requested_permission = DatasetPermissionEnum.ALL_TEAM
requested_partial_member_list = None
# Act (should not raise)
DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list)
# Assert
# Verify get_partial_member_list was not called (not needed for non-operators)
mock_get_partial_member_list.assert_not_called()
def test_check_permission_not_dataset_editor_error(self):
"""
Test error when user is not a dataset editor.
Verifies that when a user without dataset editor permissions
tries to change permissions, a NoPermissionError is raised.
This test ensures:
- Non-editors cannot change permissions
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=False)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock()
requested_permission = DatasetPermissionEnum.ALL_TEAM
requested_partial_member_list = None
# Act & Assert
with pytest.raises(NoPermissionError, match="User does not have permission to edit this dataset"):
DatasetPermissionService.check_permission(
user, dataset, requested_permission, requested_partial_member_list
)
def test_check_permission_operator_cannot_change_permission_error(self):
"""
Test error when dataset operator tries to change permission level.
Verifies that when a dataset operator tries to change the permission
level, a NoPermissionError is raised.
This test ensures:
- Dataset operators cannot change permission levels
- Error message is clear
- Current permission is preserved
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
requested_permission = DatasetPermissionEnum.ALL_TEAM # Trying to change
requested_partial_member_list = None
# Act & Assert
with pytest.raises(NoPermissionError, match="Dataset operators cannot change the dataset permissions"):
DatasetPermissionService.check_permission(
user, dataset, requested_permission, requested_partial_member_list
)
def test_check_permission_operator_partial_members_missing_list_error(self, mock_get_partial_member_list):
"""
Test error when operator sets partial_members without providing list.
Verifies that when a dataset operator tries to set permission to
partial_members without providing a member list, a ValueError is raised.
This test ensures:
- Partial member list is required for partial_members permission
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
requested_permission = "partial_members"
requested_partial_member_list = None # Missing list
# Act & Assert
with pytest.raises(ValueError, match="Partial member list is required when setting to partial members"):
DatasetPermissionService.check_permission(
user, dataset, requested_permission, requested_partial_member_list
)
def test_check_permission_operator_cannot_modify_partial_list_error(self, mock_get_partial_member_list):
"""
Test error when operator tries to modify partial member list.
Verifies that when a dataset operator tries to change the partial
member list, a ValueError is raised.
This test ensures:
- Dataset operators cannot modify partial member lists
- Error message is clear
- Current member list is preserved
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
requested_permission = "partial_members"
# Current member list
current_member_list = ["user-456", "user-789"]
mock_get_partial_member_list.return_value = current_member_list
# Requested member list (different from current)
requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock(
["user-456", "user-999"] # Different list
)
# Act & Assert
with pytest.raises(ValueError, match="Dataset operators cannot change the dataset permissions"):
DatasetPermissionService.check_permission(
user, dataset, requested_permission, requested_partial_member_list
)
def test_check_permission_operator_can_keep_same_partial_list(self, mock_get_partial_member_list):
"""
Test that operator can keep the same partial member list.
Verifies that when a dataset operator keeps the same partial member
list, the check passes.
This test ensures:
- Operators can keep existing partial member lists
- No errors are raised for unchanged lists
- Permission validation works correctly
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
requested_permission = "partial_members"
# Current member list
current_member_list = ["user-456", "user-789"]
mock_get_partial_member_list.return_value = current_member_list
# Requested member list (same as current)
requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock(
["user-456", "user-789"] # Same list
)
# Act (should not raise)
DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list)
# Assert
# Verify get_partial_member_list was called to compare lists
mock_get_partial_member_list.assert_called_once_with(dataset.id)
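# --- Editor's note: a compact, hedged restatement of the rules asserted by the class
# above for DatasetPermissionService.check_permission. Exception types and messages are
# simplified placeholders; "partial_members" is the permission value used by the tests.
def validate_permission_change(
    is_dataset_editor,
    is_dataset_operator,
    current_permission,
    requested_permission,
    current_member_ids,
    requested_member_ids,
):
    if not is_dataset_editor:
        raise PermissionError("user does not have permission to edit this dataset")
    if not is_dataset_operator:
        return  # dataset editors that are not operators may change the permission freely
    if requested_permission != current_permission:
        raise PermissionError("dataset operators cannot change the dataset permissions")
    if requested_permission == "partial_members":
        if requested_member_ids is None:
            raise ValueError("partial member list is required when setting to partial members")
        if set(requested_member_ids) != set(current_member_ids or []):
            raise ValueError("dataset operators cannot change the dataset permissions")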
# ============================================================================
# Tests for DatasetService.check_dataset_permission
# ============================================================================
class TestDatasetServiceCheckDatasetPermission:
"""
Comprehensive unit tests for DatasetService.check_dataset_permission method.
This test class covers the dataset permission checking logic that validates
whether a user has access to a dataset based on permission enums.
The check_dataset_permission method:
1. Validates tenant match
2. Checks OWNER role (bypasses some restrictions)
3. Validates only_me permission (creator only)
4. Validates partial_members permission (explicit permission required)
5. Validates all_team_members permission (all tenant members)
Test scenarios include:
- Tenant boundary enforcement
- OWNER role bypass
- only_me permission validation
- partial_members permission validation
- all_team_members permission validation
- Permission denial scenarios
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session for testing.
Provides a mocked database session that can be used to verify
database queries for permission checks.
"""
with patch("services.dataset_service.db.session") as mock_db:
yield mock_db
def test_check_dataset_permission_owner_bypass(self, mock_db_session):
"""
Test that OWNER role bypasses permission checks.
Verifies that when a user has OWNER role, they can access any
dataset in their tenant regardless of permission level.
This test ensures:
- OWNER role bypasses permission restrictions
- No database queries are needed for OWNER
- Access is granted automatically
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123")
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="other-user-123", # Not the current user
)
# Act (should not raise)
DatasetService.check_dataset_permission(dataset, user)
# Assert
# Verify no permission queries were made (OWNER bypasses)
mock_db_session.query.assert_not_called()
def test_check_dataset_permission_tenant_mismatch_error(self):
"""
Test error when user and dataset are in different tenants.
Verifies that when a user tries to access a dataset from a different
tenant, a NoPermissionError is raised.
This test ensures:
- Tenant boundary is enforced
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(tenant_id="tenant-123")
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(tenant_id="tenant-456") # Different tenant
# Act & Assert
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_only_me_creator_success(self):
"""
Test that creator can access only_me dataset.
Verifies that when a user is the creator of an only_me dataset,
they can access it successfully.
This test ensures:
- Creators can access their own only_me datasets
- No explicit permission record is needed
- Access is granted correctly
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="user-123", # User is the creator
)
# Act (should not raise)
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_only_me_non_creator_error(self):
"""
Test error when non-creator tries to access only_me dataset.
Verifies that when a user who is not the creator tries to access
an only_me dataset, a NoPermissionError is raised.
This test ensures:
- Non-creators cannot access only_me datasets
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="other-user-456", # Different creator
)
# Act & Assert
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session):
"""
Test that creator can access partial_members dataset without explicit permission.
Verifies that when a user is the creator of a partial_members dataset,
they can access it even without an explicit DatasetPermission record.
This test ensures:
- Creators can access their own datasets
- No explicit permission record is needed for creators
- Access is granted correctly
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.PARTIAL_TEAM,
created_by="user-123", # User is the creator
)
# Act (should not raise)
DatasetService.check_dataset_permission(dataset, user)
# Assert
# Verify permission query was not executed (creator bypasses)
mock_db_session.query.assert_not_called()
def test_check_dataset_permission_all_team_members_success(self):
"""
Test that any tenant member can access all_team_members dataset.
Verifies that when a dataset has all_team_members permission, any
user in the same tenant can access it.
This test ensures:
- All team members can access
- No explicit permission record is needed
- Access is granted correctly
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ALL_TEAM,
created_by="other-user-456", # Not the creator
)
# Act (should not raise)
DatasetService.check_dataset_permission(dataset, user)
# ============================================================================
# Tests for DatasetService.check_dataset_operator_permission
# ============================================================================
class TestDatasetServiceCheckDatasetOperatorPermission:
"""
Comprehensive unit tests for DatasetService.check_dataset_operator_permission method.
This test class covers the dataset operator permission checking logic,
which validates whether a dataset operator has access to a dataset.
The check_dataset_operator_permission method:
1. Validates dataset exists
2. Validates user exists
3. Checks OWNER role (bypasses restrictions)
4. Validates only_me permission (creator only)
5. Validates partial_members permission (explicit permission required)
Test scenarios include:
- Dataset not found error
- User not found error
- OWNER role bypass
- only_me permission validation
- partial_members permission validation
- Permission denial scenarios
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session for testing.
Provides a mocked database session that can be used to verify
database queries for permission checks.
"""
with patch("services.dataset_service.db.session") as mock_db:
yield mock_db
def test_check_dataset_operator_permission_dataset_not_found_error(self):
"""
Test error when dataset is None.
Verifies that when dataset is None, a ValueError is raised.
This test ensures:
- Dataset existence is validated
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock()
dataset = None
# Act & Assert
with pytest.raises(ValueError, match="Dataset not found"):
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
def test_check_dataset_operator_permission_user_not_found_error(self):
"""
Test error when user is None.
Verifies that when user is None, a ValueError is raised.
This test ensures:
- User existence is validated
- Error message is clear
- Error type is correct
"""
# Arrange
user = None
dataset = DatasetPermissionTestDataFactory.create_dataset_mock()
# Act & Assert
with pytest.raises(ValueError, match="User not found"):
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
def test_check_dataset_operator_permission_owner_bypass(self):
"""
Test that OWNER role bypasses permission checks.
Verifies that when a user has OWNER role, they can access any
dataset in their tenant regardless of permission level.
This test ensures:
- OWNER role bypasses permission restrictions
- No database queries are needed for OWNER
- Access is granted automatically
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123")
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="other-user-123", # Not the current user
)
# Act (should not raise)
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
def test_check_dataset_operator_permission_only_me_creator_success(self):
"""
Test that creator can access only_me dataset.
Verifies that when a user is the creator of an only_me dataset,
they can access it successfully.
This test ensures:
- Creators can access their own only_me datasets
- No explicit permission record is needed
- Access is granted correctly
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="user-123", # User is the creator
)
# Act (should not raise)
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
def test_check_dataset_operator_permission_only_me_non_creator_error(self):
"""
Test error when non-creator tries to access only_me dataset.
Verifies that when a user who is not the creator tries to access
an only_me dataset, a NoPermissionError is raised.
This test ensures:
- Non-creators cannot access only_me datasets
- Error message is clear
- Error type is correct
"""
# Arrange
user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123",
permission=DatasetPermissionEnum.ONLY_ME,
created_by="other-user-456", # Different creator
)
# Act & Assert
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
# ============================================================================
# Additional Documentation and Notes
# ============================================================================
#
# This test suite covers the core permission management operations for datasets.
# Additional test scenarios that could be added:
#
# 1. Permission Enum Transitions:
# - Testing transitions between permission levels
# - Testing validation during transitions
# - Testing partial member list updates during transitions
#
# 2. Bulk Operations:
# - Testing bulk permission updates
# - Testing bulk partial member list updates
# - Testing performance with large member lists
#
# 3. Edge Cases:
# - Testing with very large partial member lists
# - Testing with special characters in user IDs
# - Testing with deleted users
# - Testing with inactive permissions
#
# 4. Integration Scenarios:
# - Testing permission changes followed by access attempts
# - Testing concurrent permission updates
# - Testing permission inheritance
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================
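# ----------------------------------------------------------------------------
# Hedged sketch of one scenario from the list above (partial-member access).
# NOT part of the original suite: it assumes DatasetPermissionEnum has a
# PARTIAL_TEAM member and that check_dataset_operator_permission resolves the
# member list through services.dataset_service.db.session using a
# query(...).filter_by(...).first() chain; both details are assumptions.
# ----------------------------------------------------------------------------
def test_check_dataset_operator_permission_partial_team_without_record_sketch():
    # Arrange: a normal member who did not create the partial-team dataset
    user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
    dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
        tenant_id="tenant-123",
        permission=DatasetPermissionEnum.PARTIAL_TEAM,  # assumed enum member
        created_by="other-user-456",
    )
    with patch("services.dataset_service.db.session") as mock_db:
        # Simulate the absence of a DatasetPermission record for this user
        mock_db.query.return_value.filter_by.return_value.first.return_value = None
        # Act & Assert: access should be denied
        with pytest.raises(NoPermissionError):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)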

View File

@@ -1,818 +0,0 @@
"""
Comprehensive unit tests for DatasetService update and delete operations.
This module contains extensive unit tests for the DatasetService class,
specifically focusing on update and delete operations for datasets.
The DatasetService provides methods for:
- Updating dataset configuration and settings (update_dataset)
- Deleting datasets with proper cleanup (delete_dataset)
- Updating RAG pipeline dataset settings (update_rag_pipeline_dataset_settings)
- Checking if dataset is in use (dataset_use_check)
- Updating dataset API access status (update_dataset_api_status)
These operations are critical for dataset lifecycle management and require
careful handling of permissions, dependencies, and data integrity.
This test suite ensures:
- Correct update of dataset properties
- Proper permission validation before updates/deletes
- Cascade deletion handling
- Event signaling for cleanup operations
- RAG pipeline dataset configuration updates
- API status management
- Use check validation
================================================================================
ARCHITECTURE OVERVIEW
================================================================================
The DatasetService update and delete operations are part of the dataset
lifecycle management system. These operations interact with multiple
components:
1. Permission System: All update/delete operations require proper
permission validation to ensure users can only modify datasets they
have access to.
2. Event System: Dataset deletion triggers the dataset_was_deleted event,
which notifies other components to clean up related data (documents,
segments, vector indices, etc.).
3. Dependency Checking: Before deletion, the system checks if the dataset
is in use by any applications (via AppDatasetJoin).
4. RAG Pipeline Integration: RAG pipeline datasets have special update
logic that handles chunk structure, indexing techniques, and embedding
model configuration.
5. API Status Management: Datasets can have their API access enabled or
disabled, which affects whether they can be accessed via the API.
================================================================================
TESTING STRATEGY
================================================================================
This test suite follows a comprehensive testing strategy that covers:
1. Update Operations:
- Internal dataset updates
- External dataset updates
- RAG pipeline dataset updates
- Permission validation
- Name duplicate checking
- Configuration validation
2. Delete Operations:
- Successful deletion
- Permission validation
- Event signaling
- Database cleanup
- Not found handling
3. Use Check Operations:
- Dataset in use detection
- Dataset not in use detection
- AppDatasetJoin query validation
4. API Status Operations:
- Enable API access
- Disable API access
- Permission validation
- Current user validation
5. RAG Pipeline Operations:
- Unpublished dataset updates
- Published dataset updates
- Chunk structure validation
- Indexing technique changes
- Embedding model configuration
================================================================================
"""
import datetime
from unittest.mock import Mock, create_autospec, patch
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetPermissionEnum,
)
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
# ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of models or services changes, we only
# need to update the factory methods rather than every individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
# reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
# method calls instead of complex object construction logic.
#
# ============================================================================
class DatasetUpdateDeleteTestDataFactory:
"""
Factory class for creating test data and mock objects for dataset update/delete tests.
This factory provides static methods to create mock objects for:
- Dataset instances with various configurations
- User/Account instances with different roles
- Knowledge configuration objects
- Database session mocks
- Event signal mocks
The factory methods help maintain consistency across tests and reduce
code duplication when setting up test scenarios.
"""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
provider: str = "vendor",
name: str = "Test Dataset",
description: str = "Test description",
tenant_id: str = "tenant-123",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str | None = "openai",
embedding_model: str | None = "text-embedding-ada-002",
collection_binding_id: str | None = "binding-123",
enable_api: bool = True,
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
created_by: str = "user-123",
chunk_structure: str | None = None,
runtime_mode: str = "general",
**kwargs,
) -> Mock:
"""
Create a mock Dataset with specified attributes.
Args:
dataset_id: Unique identifier for the dataset
provider: Dataset provider (vendor, external)
name: Dataset name
description: Dataset description
tenant_id: Tenant identifier
indexing_technique: Indexing technique (high_quality, economy)
embedding_model_provider: Embedding model provider
embedding_model: Embedding model name
collection_binding_id: Collection binding ID
enable_api: Whether API access is enabled
permission: Dataset permission level
created_by: ID of user who created the dataset
chunk_structure: Chunk structure for RAG pipeline datasets
runtime_mode: Runtime mode (general, rag_pipeline)
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a Dataset instance
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.provider = provider
dataset.name = name
dataset.description = description
dataset.tenant_id = tenant_id
dataset.indexing_technique = indexing_technique
dataset.embedding_model_provider = embedding_model_provider
dataset.embedding_model = embedding_model
dataset.collection_binding_id = collection_binding_id
dataset.enable_api = enable_api
dataset.permission = permission
dataset.created_by = created_by
dataset.chunk_structure = chunk_structure
dataset.runtime_mode = runtime_mode
dataset.retrieval_model = {}
dataset.keyword_number = 10
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-123",
tenant_id: str = "tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
is_dataset_editor: bool = True,
**kwargs,
) -> Mock:
"""
Create a mock user (Account) with specified attributes.
Args:
user_id: Unique identifier for the user
tenant_id: Tenant identifier
role: User role (OWNER, ADMIN, NORMAL, etc.)
is_dataset_editor: Whether user has dataset editor permissions
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as an Account instance
"""
user = create_autospec(Account, instance=True)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
user.is_dataset_editor = is_dataset_editor
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_knowledge_configuration_mock(
chunk_structure: str = "tree",
indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
keyword_number: int = 10,
retrieval_model: dict | None = None,
**kwargs,
) -> Mock:
"""
Create a mock KnowledgeConfiguration entity.
Args:
chunk_structure: Chunk structure type
indexing_technique: Indexing technique
embedding_model_provider: Embedding model provider
embedding_model: Embedding model name
keyword_number: Keyword number for economy indexing
retrieval_model: Retrieval model configuration
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a KnowledgeConfiguration instance
"""
config = Mock()
config.chunk_structure = chunk_structure
config.indexing_technique = indexing_technique
config.embedding_model_provider = embedding_model_provider
config.embedding_model = embedding_model
config.keyword_number = keyword_number
config.retrieval_model = Mock()
config.retrieval_model.model_dump.return_value = retrieval_model or {
"search_method": "semantic_search",
"top_k": 2,
}
for key, value in kwargs.items():
setattr(config, key, value)
return config
@staticmethod
def create_app_dataset_join_mock(
app_id: str = "app-123",
dataset_id: str = "dataset-123",
**kwargs,
) -> Mock:
"""
Create a mock AppDatasetJoin instance.
Args:
app_id: Application ID
dataset_id: Dataset ID
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as an AppDatasetJoin instance
"""
join = Mock(spec=AppDatasetJoin)
join.app_id = app_id
join.dataset_id = dataset_id
for key, value in kwargs.items():
setattr(join, key, value)
return join
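# ----------------------------------------------------------------------------
# Hedged sketch of a dataset_use_check test (the "is the dataset in use?" check
# named in the module docstring above). NOT part of the original suite: it
# assumes DatasetService.dataset_use_check(dataset_id) answers truthily based on
# counting AppDatasetJoin rows through the patched db.session; the exact query
# chain is an assumption, and create_app_dataset_join_mock above can be used for
# richer variants.
# ----------------------------------------------------------------------------
def test_dataset_use_check_sketch():
    with patch("extensions.ext_database.db.session") as mock_db:
        # One AppDatasetJoin row references the dataset, so it counts as in use
        mock_db.query.return_value.filter_by.return_value.count.return_value = 1
        assert DatasetService.dataset_use_check("dataset-123")
        # No rows reference the dataset, so it is not in use
        mock_db.query.return_value.filter_by.return_value.count.return_value = 0
        assert not DatasetService.dataset_use_check("dataset-123")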
# ============================================================================
# Tests for update_dataset
# ============================================================================
class TestDatasetServiceUpdateDataset:
"""
Comprehensive unit tests for DatasetService.update_dataset method.
This test class covers the dataset update functionality, including
internal and external dataset updates, permission validation, and
name duplicate checking.
The update_dataset method:
1. Retrieves the dataset by ID
2. Validates dataset exists
3. Checks for duplicate names
4. Validates user permissions
5. Routes to appropriate update handler (internal or external)
6. Returns the updated dataset
Test scenarios include:
- Successful internal dataset updates
- Successful external dataset updates
- Permission validation
- Duplicate name detection
- Dataset not found errors
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""
Mock dataset service dependencies for testing.
Provides mocked dependencies including:
- get_dataset method
- check_dataset_permission method
- _has_dataset_same_name method
- Database session
- Current time utilities
"""
with (
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_naive_utc_now.return_value = current_time
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"has_same_name": mock_has_same_name,
"db_session": mock_db,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
def test_update_dataset_internal_success(self, mock_dataset_service_dependencies):
"""
Test successful update of an internal dataset.
Verifies that when all validation passes, an internal dataset
is updated correctly through the _update_internal_dataset method.
This test ensures:
- Dataset is retrieved correctly
- Permission is checked
- Name duplicate check is performed
- Internal update handler is called
- Updated dataset is returned
"""
# Arrange
dataset_id = "dataset-123"
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id=dataset_id, provider="vendor", name="Old Name"
)
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
update_data = {
"name": "New Name",
"description": "New Description",
}
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
mock_dataset_service_dependencies["has_same_name"].return_value = False
with patch("services.dataset_service.DatasetService._update_internal_dataset") as mock_update_internal:
mock_update_internal.return_value = dataset
# Act
result = DatasetService.update_dataset(dataset_id, update_data, user)
# Assert
assert result == dataset
# Verify dataset was retrieved
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
# Verify permission was checked
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify name duplicate check was performed
mock_dataset_service_dependencies["has_same_name"].assert_called_once()
# Verify internal update handler was called
mock_update_internal.assert_called_once()
def test_update_dataset_external_success(self, mock_dataset_service_dependencies):
"""
Test successful update of an external dataset.
Verifies that when all validation passes, an external dataset
is updated correctly through the _update_external_dataset method.
This test ensures:
- Dataset is retrieved correctly
- Permission is checked
- Name duplicate check is performed
- External update handler is called
- Updated dataset is returned
"""
# Arrange
dataset_id = "dataset-123"
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id=dataset_id, provider="external", name="Old Name"
)
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
update_data = {
"name": "New Name",
"external_knowledge_id": "new-knowledge-id",
}
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
mock_dataset_service_dependencies["has_same_name"].return_value = False
with patch("services.dataset_service.DatasetService._update_external_dataset") as mock_update_external:
mock_update_external.return_value = dataset
# Act
result = DatasetService.update_dataset(dataset_id, update_data, user)
# Assert
assert result == dataset
# Verify external update handler was called
mock_update_external.assert_called_once()
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
"""
Test error handling when dataset is not found.
Verifies that when the dataset ID doesn't exist, a ValueError
is raised with an appropriate message.
This test ensures:
- Dataset not found error is handled correctly
- No update operations are performed
- Error message is clear
"""
# Arrange
dataset_id = "non-existent-dataset"
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
update_data = {"name": "New Name"}
mock_dataset_service_dependencies["get_dataset"].return_value = None
# Act & Assert
with pytest.raises(ValueError, match="Dataset not found"):
DatasetService.update_dataset(dataset_id, update_data, user)
# Verify no update operations were attempted
mock_dataset_service_dependencies["check_permission"].assert_not_called()
mock_dataset_service_dependencies["has_same_name"].assert_not_called()
def test_update_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
"""
Test error handling when dataset name already exists.
Verifies that when a dataset with the same name already exists
in the tenant, a ValueError is raised.
This test ensures:
- Duplicate name detection works correctly
- Error message is clear
- No update operations are performed
"""
# Arrange
dataset_id = "dataset-123"
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
update_data = {"name": "Existing Name"}
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
mock_dataset_service_dependencies["has_same_name"].return_value = True # Duplicate exists
# Act & Assert
with pytest.raises(ValueError, match="Dataset name already exists"):
DatasetService.update_dataset(dataset_id, update_data, user)
# Verify permission check was not called (fails before that)
mock_dataset_service_dependencies["check_permission"].assert_not_called()
def test_update_dataset_permission_denied_error(self, mock_dataset_service_dependencies):
"""
Test error handling when user lacks permission.
Verifies that when the user doesn't have permission to update
the dataset, a NoPermissionError is raised.
This test ensures:
- Permission validation works correctly
- Error is raised before any updates
- Error type is correct
"""
# Arrange
dataset_id = "dataset-123"
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
update_data = {"name": "New Name"}
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
mock_dataset_service_dependencies["has_same_name"].return_value = False
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
# Act & Assert
with pytest.raises(NoPermissionError):
DatasetService.update_dataset(dataset_id, update_data, user)
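# ----------------------------------------------------------------------------
# Hedged sketch for the API status operation named in the module docstring
# (update_dataset_api_status). NOT part of the original suite, and the call
# shape below is an assumption made only for illustration; the real method may
# take different parameters. It mirrors the patching style of the tests above.
# ----------------------------------------------------------------------------
def test_update_dataset_api_status_disable_sketch():
    dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(enable_api=True)
    user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
    with (
        patch("services.dataset_service.DatasetService.get_dataset", return_value=dataset),
        patch("services.dataset_service.DatasetService.check_dataset_permission"),
        patch("extensions.ext_database.db.session"),
    ):
        # Assumed signature: (dataset_id, enable_api, user)
        DatasetService.update_dataset_api_status("dataset-123", enable_api=False, user=user)
        assert dataset.enable_api is False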
# ============================================================================
# Tests for update_rag_pipeline_dataset_settings
# ============================================================================
class TestDatasetServiceUpdateRagPipelineDatasetSettings:
"""
Comprehensive unit tests for DatasetService.update_rag_pipeline_dataset_settings method.
This test class covers the RAG pipeline dataset settings update functionality,
including chunk structure, indexing technique, and embedding model configuration.
The update_rag_pipeline_dataset_settings method:
1. Validates current_user and tenant
2. Merges dataset into session
3. Handles unpublished vs published datasets differently
4. Updates chunk structure, indexing technique, and retrieval model
5. Configures embedding model for high_quality indexing
6. Updates keyword_number for economy indexing
7. Commits transaction
8. Triggers index update tasks if needed
Test scenarios include:
- Unpublished dataset updates
- Published dataset updates
- Chunk structure validation
- Indexing technique changes
- Embedding model configuration
- Error handling
"""
@pytest.fixture
def mock_session(self):
"""
Mock database session for testing.
Provides a mocked SQLAlchemy session for testing session operations.
"""
return Mock(spec=Session)
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""
Mock dataset service dependencies for testing.
Provides mocked dependencies including:
- current_user context
- ModelManager
- DatasetCollectionBindingService
- Database session operations
- Task scheduling
"""
with (
patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_index_update_task") as mock_task,
):
mock_current_user.current_tenant_id = "tenant-123"
mock_current_user.id = "user-123"
yield {
"current_user": mock_current_user,
"model_manager": mock_model_manager,
"get_binding": mock_get_binding,
"task": mock_task,
}
def test_update_rag_pipeline_dataset_settings_unpublished_success(
self, mock_session, mock_dataset_service_dependencies
):
"""
Test successful update of unpublished RAG pipeline dataset.
Verifies that when a dataset is not published, all settings can
be updated including chunk structure and indexing technique.
This test ensures:
- Current user validation passes
- Dataset is merged into session
- Chunk structure is updated
- Indexing technique is updated
- Embedding model is configured for high_quality
- Retrieval model is updated
- Dataset is added to session
"""
# Arrange
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list",
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model_name = "text-embedding-ada-002"
mock_embedding_model.provider = "openai"
mock_embedding_model.credentials = {}
mock_model_schema = Mock()
mock_model_schema.features = []
mock_text_embedding_model = Mock()
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
mock_embedding_model.model_type_instance = mock_text_embedding_model
mock_model_instance = Mock()
mock_model_instance.get_model_instance.return_value = mock_embedding_model
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_instance
# Mock collection binding
mock_binding = Mock()
mock_binding.id = "binding-123"
mock_dataset_service_dependencies["get_binding"].return_value = mock_binding
mock_session.merge.return_value = dataset
# Act
DatasetService.update_rag_pipeline_dataset_settings(
mock_session, dataset, knowledge_config, has_published=False
)
# Assert
assert dataset.chunk_structure == "list"
assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == "binding-123"
# Verify dataset was added to session
mock_session.add.assert_called_once_with(dataset)
def test_update_rag_pipeline_dataset_settings_published_chunk_structure_error(
self, mock_session, mock_dataset_service_dependencies
):
"""
Test error handling when trying to update the chunk structure of a published dataset.
Verifies that when a dataset is published and has an existing chunk structure,
attempting to change it raises a ValueError.
This test ensures:
- Chunk structure change is detected
- ValueError is raised with appropriate message
- No updates are committed
"""
# Arrange
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
chunk_structure="tree", # Existing structure
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
chunk_structure="list", # Different structure
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
)
mock_session.merge.return_value = dataset
# Act & Assert
with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"):
DatasetService.update_rag_pipeline_dataset_settings(
mock_session, dataset, knowledge_config, has_published=True
)
# Verify no commit was attempted
mock_session.commit.assert_not_called()
def test_update_rag_pipeline_dataset_settings_published_economy_error(
self, mock_session, mock_dataset_service_dependencies
):
"""
Test error handling when trying to change to economy indexing on published dataset.
Verifies that when a dataset is published, changing indexing technique to
economy is not allowed and raises a ValueError.
This test ensures:
- Economy indexing change is detected
- ValueError is raised with appropriate message
- No updates are committed
"""
# Arrange
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
dataset_id="dataset-123",
runtime_mode="rag_pipeline",
indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique
)
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy
)
mock_session.merge.return_value = dataset
# Act & Assert
with pytest.raises(
ValueError, match="Knowledge base indexing technique is not allowed to be updated to economy"
):
DatasetService.update_rag_pipeline_dataset_settings(
mock_session, dataset, knowledge_config, has_published=True
)
def test_update_rag_pipeline_dataset_settings_missing_current_user_error(
self, mock_session, mock_dataset_service_dependencies
):
"""
Test error handling when current_user is missing.
Verifies that when current_user is None or has no tenant ID, a ValueError
is raised.
This test ensures:
- Current user validation works correctly
- Error message is clear
- No updates are performed
"""
# Arrange
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock()
knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock()
mock_dataset_service_dependencies["current_user"].current_tenant_id = None # Missing tenant
# Act & Assert
with pytest.raises(ValueError, match="Current user or current tenant not found"):
DatasetService.update_rag_pipeline_dataset_settings(
mock_session, dataset, knowledge_config, has_published=False
)
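# ----------------------------------------------------------------------------
# Hedged sketch for the economy-indexing branch (step 6 in the class docstring:
# "Updates keyword_number for economy indexing"). NOT part of the original
# suite: it assumes the unpublished economy path needs no embedding model,
# collection binding, or index-update task, and that keyword_number is copied
# from the knowledge configuration.
# ----------------------------------------------------------------------------
def test_update_rag_pipeline_dataset_settings_unpublished_economy_sketch():
    session = Mock(spec=Session)
    dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
        runtime_mode="rag_pipeline",
        indexing_technique=IndexTechniqueType.ECONOMY,
        chunk_structure="tree",
    )
    knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
        chunk_structure="tree",
        indexing_technique=IndexTechniqueType.ECONOMY,
        keyword_number=25,
    )
    session.merge.return_value = dataset
    with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as mock_user:
        mock_user.current_tenant_id = "tenant-123"
        mock_user.id = "user-123"
        DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_config, has_published=False)
    assert dataset.indexing_technique == IndexTechniqueType.ECONOMY
    assert dataset.keyword_number == 25
    session.add.assert_called_once_with(dataset)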
# ============================================================================
# Additional Documentation and Notes
# ============================================================================
#
# This test suite covers the core update and delete operations for datasets.
# Additional test scenarios that could be added:
#
# 1. Update Operations:
# - Testing with different indexing techniques
# - Testing embedding model provider changes
# - Testing retrieval model updates
# - Testing icon_info updates
# - Testing partial_member_list updates
#
# 2. Delete Operations:
# - Testing cascade deletion of related data
# - Testing event handler execution
# - Testing with datasets that have documents
# - Testing with datasets that have segments
#
# 3. RAG Pipeline Operations:
# - Testing economy indexing technique updates
# - Testing embedding model provider errors
# - Testing keyword_number updates
# - Testing index update task triggering
#
# 4. Integration Scenarios:
# - Testing update followed by delete
# - Testing multiple updates in sequence
# - Testing concurrent update attempts
# - Testing with different user roles
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================
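# ----------------------------------------------------------------------------
# Hedged sketch of one delete-operation scenario from the list above, checking
# that deletion emits the cleanup signal described in the architecture overview.
# NOT part of the original suite: it assumes delete_dataset(dataset_id, user)
# returns True on success and that the signal is importable in the service
# module as services.dataset_service.dataset_was_deleted.
# ----------------------------------------------------------------------------
def test_delete_dataset_sends_cleanup_signal_sketch():
    dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id="dataset-123")
    user = DatasetUpdateDeleteTestDataFactory.create_user_mock()
    with (
        patch("services.dataset_service.DatasetService.get_dataset", return_value=dataset),
        patch("services.dataset_service.DatasetService.check_dataset_permission"),
        patch("services.dataset_service.dataset_was_deleted") as mock_signal,
        patch("extensions.ext_database.db.session") as mock_db,
    ):
        result = DatasetService.delete_dataset("dataset-123", user)
        assert result is True
        # The cleanup event fires so listeners can remove documents, segments,
        # and vector indices tied to the dataset
        mock_signal.send.assert_called_once_with(dataset)
        mock_db.delete.assert_called_once_with(dataset)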

View File

@@ -247,10 +247,11 @@ workflow:
dataset_mock = Mock()
dataset_mock.id = "d1"
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
session = cast(MagicMock, Mock())
service = RagPipelineDslService(session=cast(Session, session))
session.query.return_value.filter_by.return_value.all.return_value = []
session.scalars.return_value.all.return_value = []
account = Mock(current_tenant_id="t1")
result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content)
@@ -320,6 +321,7 @@ workflow:
dataset_mock.id = "d1"
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", return_value=Mock(id="b1"))
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
service = RagPipelineDslService(session=Mock())
# Mocking self._session.scalar for the pipeline lookup
@@ -406,12 +408,14 @@ def test_create_or_update_pipeline_create_new(mocker) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock())
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
pipeline_instance = pipeline_cls.return_value
pipeline_instance.tenant_id = "t1"
pipeline_instance.id = "p1"
pipeline_instance.name = "P"
pipeline_instance.is_published = False
session.scalar.return_value = None
result = service._create_or_update_pipeline(pipeline=None, data=data, account=account, dependencies=[])
@@ -447,8 +451,7 @@ def test_export_rag_pipeline_dsl_with_workflow(mocker) -> None:
workflow.rag_pipeline_variables = []
workflow.to_dict.return_value = {"graph": {"nodes": []}}
# Mocking single .where() call
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
return_value=[],
@@ -550,7 +553,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
]
}
}
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
return_value=[],
@@ -568,7 +571,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
session = cast(MagicMock, Mock())
service = RagPipelineDslService(session=cast(Session, session))
session.query.return_value.filter_by.return_value.first.return_value = Mock()
session.scalar.return_value = Mock()
create_entity = RagPipelineDatasetCreateEntity(
name="Existing Name",
description="",
@@ -584,8 +587,8 @@ def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
def test_create_rag_pipeline_dataset_generates_name_when_missing(mocker) -> None:
session = cast(MagicMock, Mock())
service = RagPipelineDslService(session=cast(Session, session))
session.query.return_value.filter_by.return_value.first.return_value = None
session.query.return_value.filter_by.return_value.all.return_value = [Mock(name="Untitled")]
session.scalar.return_value = None
session.scalars.return_value.all.return_value = [Mock(name="Untitled")]
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="Untitled 2")
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", Mock(id="u1", current_tenant_id="t1"))
mocker.patch.object(
@@ -632,7 +635,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo
]
}
}
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
mocker.patch.object(service, "encrypt_dataset_id", side_effect=lambda dataset_id, tenant_id: f"enc-{dataset_id}")
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
@@ -727,7 +730,7 @@ def test_create_or_update_pipeline_decrypts_knowledge_retrieval_dataset_ids(mock
},
}
draft_workflow = Mock(id="wf1")
session.query.return_value.where.return_value.first.return_value = draft_workflow
session.scalar.return_value = draft_workflow
mocker.patch.object(service, "decrypt_dataset_id", side_effect=["d1", None])
result = service._create_or_update_pipeline(pipeline=pipeline, data=data, account=account)
@@ -743,7 +746,8 @@ def test_create_or_update_pipeline_creates_draft_when_missing(mocker) -> None:
account = Mock(id="u1", current_tenant_id="t1")
pipeline = Mock(id="p1", tenant_id="t1", name="N", description="D")
data = {"rag_pipeline": {"name": "N2", "description": "D2"}, "workflow": {"graph": {"nodes": []}}}
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
workflow_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow")
workflow_cls.return_value.id = "wf-new"
@@ -817,7 +821,7 @@ def test_import_rag_pipeline_fails_for_non_string_version_type() -> None:
def test_append_workflow_export_data_raises_when_draft_workflow_missing() -> None:
session = cast(MagicMock, Mock())
service = RagPipelineDslService(session=cast(Session, session))
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="Missing draft workflow configuration"):
service._append_workflow_export_data(export_data={}, pipeline=Mock(tenant_id="t1"), include_secret=False)
@@ -841,7 +845,7 @@ def test_append_workflow_export_data_keeps_secret_fields_when_include_secret_tru
]
}
}
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
return_value=[],
@@ -1003,7 +1007,8 @@ def test_import_rag_pipeline_sets_default_version_and_kind(mocker) -> None:
)
dataset = Mock(id="d1")
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset)
session.query.return_value.filter_by.return_value.all.return_value = []
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
session.scalars.return_value.all.return_value = []
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="P")
result = service.import_rag_pipeline(
@@ -1061,7 +1066,7 @@ def test_append_workflow_export_data_skips_empty_node_data(mocker) -> None:
workflow = Mock()
workflow.graph_dict = {"nodes": []}
workflow.to_dict.return_value = {"graph": {"nodes": [{"data": {}}, {}]}}
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
mocker.patch(
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
return_value=[],
@@ -1246,11 +1251,12 @@ def test_create_or_update_pipeline_saves_dependencies_to_redis(mocker) -> None:
account = Mock(id="u1", current_tenant_id="t1")
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock(id="wf-1"))
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
pipeline = pipeline_cls.return_value
pipeline.tenant_id = "t1"
pipeline.id = "p1"
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
setex = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.setex")
dependency = PluginDependency(
type=PluginDependency.Type.Marketplace,

View File

@@ -116,81 +116,6 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
assert has_more is True
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
# --- update_customized_pipeline_template ---
def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First scalar finds the template, second scalar (duplicate check) returns None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(
name="new",
description="new desc",
icon_info=IconInfo(icon="🔥"),
)
result = RagPipelineService.update_customized_pipeline_template("tpl-1", info)
assert result.name == "new"
assert result.description == "new desc"
def test_update_customized_pipeline_template_not_found(mocker) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.update_customized_pipeline_template("tpl-missing", info)
def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup")
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
with pytest.raises(ValueError, match="Template name is already exists"):
RagPipelineService.update_customized_pipeline_template("tpl-1", info)
# --- delete_customized_pipeline_template ---
def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1")
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
RagPipelineService.delete_customized_pipeline_template("tpl-1")
delete_mock.assert_called_once_with(template)
commit_mock.assert_called_once()
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.delete_customized_pipeline_template("tpl-missing")
# --- sync_draft_workflow ---

View File

@@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs:
def mock_session(self):
"""Create a mock database session."""
session = Mock(spec=Session)
session.query.return_value.filter.return_value.all.return_value = []
session.query.return_value.filter.return_value.delete.return_value = 0
session.scalars.return_value.all.return_value = []
return session
@pytest.fixture
@@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs:
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
# Should not call any database operations
mock_session.query.assert_not_called()
mock_session.scalars.assert_not_called()
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
"""Test when no related records are found."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call query for each related table but find no records
assert mock_session.query.call_count > 0
# Should call scalars for each related table but find no records
assert mock_session.scalars.call_count > 0
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_with_records_and_to_dict(
@@ -73,7 +72,7 @@
):
"""Test when records are found and have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
@@ -104,7 +103,7 @@
records.append(record)
# Mock records for first table only, empty for others
mock_session.query.return_value.where.return_value.all.side_effect = [
mock_session.scalars.return_value.all.side_effect = [
records,
[],
[],
@@ -126,13 +125,13 @@
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_storage.save.side_effect = Exception("Storage error")
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if backup fails
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
"""Test that method continues even when record serialization fails."""
@@ -141,23 +140,23 @@
record.id = "record-1"
record.to_dict.side_effect = Exception("Serialization error")
mock_session.query.return_value.where.return_value.all.return_value = [record]
mock_session.scalars.return_value.all.return_value = [record]
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if serialization fails
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
"""Test that deletion is called for found records."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
mock_session.scalars.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call delete for each table that has records
assert mock_session.query.return_value.where.return_value.delete.called
# Should call execute(delete(...)) for each table that has records
assert mock_session.execute.called
def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes(
self, mock_session, sample_message_ids
@@ -167,12 +166,12 @@
record.to_dict.side_effect = Exception("Serialization error")
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = [record]
mock_session.scalars.return_value.all.return_value = [record]
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
mock_storage.save.assert_not_called()
assert mock_session.query.return_value.where.return_value.delete.called
assert mock_session.execute.called
class _ImmediateFuture:
@@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"})
log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"})
def make_query_with_batches(batches: list[list[object]]):
q = MagicMock()
q.where.return_value = q
q.limit.return_value = q
q.all.side_effect = batches
q.delete.return_value = 1
return q
msg_session_1 = MagicMock()
msg_session_1.query.side_effect = lambda model: (
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
)
msg_session_1.scalars.return_value.all.return_value = [msg1]
msg_session_2 = MagicMock()
msg_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
)
msg_session_2.scalars.return_value.all.return_value = []
conv_session_1 = MagicMock()
conv_session_1.query.side_effect = lambda model: (
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
)
conv_session_1.scalars.return_value.all.return_value = [conv1]
conv_session_2 = MagicMock()
conv_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
)
conv_session_2.scalars.return_value.all.return_value = []
wal_session_1 = MagicMock()
wal_session_1.query.side_effect = lambda model: (
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_1.scalars.return_value.all.return_value = [log1]
wal_session_2 = MagicMock()
wal_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_2.scalars.return_value.all.return_value = []
session_wrappers = [
_sessionmaker_wrapper_for_begin(msg_session_1),
@@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
# Total tenant count query
count_session = MagicMock()
count_query = MagicMock()
count_query.count.return_value = 2
count_session.query.return_value = count_query
count_session.scalar.return_value = 2
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
@@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
# Sessions used:
# 1) total tenant count
# 2) per-batch tenant scan (count + tenant list)
# 2) per-batch tenant scan (interval counts + tenant list)
total_session = MagicMock()
total_query = MagicMock()
total_query.count.return_value = 250
total_session.query.return_value = total_query
batch_session = MagicMock()
q1 = MagicMock()
q1.where.return_value = q1
q1.count.return_value = 200
q2 = MagicMock()
q2.where.return_value = q2
q2.count.return_value = 200
q3 = MagicMock()
q3.where.return_value = q3
q3.count.return_value = 200
q4 = MagicMock()
q4.where.return_value = q4
q4.count.return_value = 50 # choose this interval, then scale it
total_session.scalar.return_value = 250
rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")]
q_rs = MagicMock()
q_rs.where.return_value = q_rs
q_rs.order_by.return_value = rows
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
batch_session = MagicMock()
# 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h)
batch_session.scalar.side_effect = [200, 200, 200, 50]
batch_session.execute.return_value = rows
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))
count_session = MagicMock()
count_query = MagicMock()
count_query.count.return_value = 100
count_session.query.return_value = count_query
count_session.scalar.return_value = 100
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
flask_app = service_module.Flask("test-app")
@@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)
total_session = MagicMock()
total_query = MagicMock()
total_query.count.return_value = 250
total_session.query.return_value = total_query
batch_session = MagicMock()
# Count results for all 5 intervals, all > 100 => take the for-else path.
count_queries = []
for _ in range(5):
q = MagicMock()
q.where.return_value = q
q.count.return_value = 200
count_queries.append(q)
total_session.scalar.return_value = 250
rows = [SimpleNamespace(id="tenant-a")]
q_rs = MagicMock()
q_rs.where.return_value = q_rs
q_rs.order_by.return_value = rows
batch_session.query.side_effect = [*count_queries, q_rs]
batch_session = MagicMock()
# All 5 intervals have > 100 tenants => for-else falls through to min interval (1h)
batch_session.scalar.side_effect = [200, 200, 200, 200, 200]
batch_session.execute.return_value = rows
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])
assert process_tenant_mock.call_count == 1
assert len(count_queries) == 5
assert batch_session.query.call_count >= 6
assert batch_session.scalar.call_count == 5
def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
# Make message/conversation/workflow_app_log loops no-op (empty immediately)
empty_session = MagicMock()
q_empty = MagicMock()
q_empty.where.return_value = q_empty
q_empty.limit.return_value = q_empty
q_empty.all.return_value = []
empty_session.query.return_value = q_empty
empty_session.scalars.return_value.all.return_value = []
session_wrappers = [
_sessionmaker_wrapper_for_begin(empty_session),
_sessionmaker_wrapper_for_begin(empty_session),

View File

@@ -577,7 +577,7 @@ class TestDatasetServiceCreationAndUpdate:
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = binding
session.scalar.return_value = binding
session.add = MagicMock()
session_context = _make_session_context(session)
@@ -596,7 +596,7 @@
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
session = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
session_context = _make_session_context(session)
mock_sessionmaker = MagicMock()

View File

@@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:
def test_update_documents_need_summary_updates_matching_documents_and_commits(self):
session = MagicMock()
session.query.return_value.filter.return_value.update.return_value = 2
session.execute.return_value.rowcount = 2
with patch("services.dataset_service.session_factory") as session_factory_mock:
session_factory_mock.create_session.return_value = _make_session_context(session)
@@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
knowledge_config = KnowledgeConfig(
indexing_technique="economy",
data_source=DataSource(
info_list=InfoList(
data_source_type="upload_file",
file_info_list=FileInfo(file_ids=["file-1"]),
)
),
process_rule=ProcessRule(
mode="hierarchical",
rules=Rule(
pre_processing_rules=[
PreProcessingRule(id="remove_extra_spaces", enabled=True),
],
segmentation=Segmentation(separator="\n", max_tokens=1024),
subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
),
),
)
DocumentService.process_rule_args_validate(knowledge_config)
assert knowledge_config.process_rule is not None
assert knowledge_config.process_rule.rules is not None
assert knowledge_config.process_rule.rules.parent_mode == "paragraph"
class TestDocumentServiceSaveDocumentWithDatasetId:
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id."""

View File

@@ -179,11 +179,11 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock()
mock_db_session.scalar.return_value = MagicMock()
assert service.is_system_oauth_params_exist(make_id()) is True
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.is_system_oauth_params_exist(make_id()) is False
# -----------------------------------------------------------------------
@@ -205,7 +205,7 @@
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
service.remove_oauth_custom_client_params("t1", make_id())
mock_db_session.query().delete.assert_called_once()
mock_db_session.execute.assert_called_once()
# -----------------------------------------------------------------------
# setup_oauth_custom_client_params (315-351)
@@ -217,14 +217,14 @@
mock_db_session.add.assert_not_called()
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
mock_db_session.add.assert_called_once()
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
existing = MagicMock()
mock_db_session.query().first.return_value = existing
mock_db_session.scalar.return_value = existing
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
mock_db_session.add.assert_not_called() # update in place, no add
@@ -255,7 +255,7 @@
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
@@ -264,7 +264,7 @@
p.auth_type = "oauth2"
p.expires_at = 0 # expired
p.encrypted_credentials = {"tok": "x"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@@ -278,7 +278,7 @@
p.auth_type = "api_key"
p.expires_at = -1 # sentinel: never expires
p.encrypted_credentials = {"k": "v"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
@ -292,7 +292,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.expires_at = -1
p.encrypted_credentials = {}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
@ -306,7 +306,7 @@ class TestDatasourceProviderService:
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
@ -314,7 +314,7 @@ class TestDatasourceProviderService:
p.auth_type = "oauth2"
p.expires_at = 0
p.encrypted_credentials = {"t": "x"}
mock_db_session.query().all.return_value = [p]
mock_db_session.scalars.return_value.all.return_value = [p]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@ -328,22 +328,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "same"
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
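# Illustrative sketch of the side_effect convention the comments in this file rely on
# (names and values here are placeholders): consecutive session.scalar() calls consume
# the list in order, so "first: fetch provider, second: name conflict count".
from unittest.mock import MagicMock

provider = object()
session = MagicMock()
session.scalar.side_effect = [provider, 1]
assert session.scalar("select provider") is provider  # first scalar() call
assert session.scalar("select count") == 1            # second scalar() call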
@ -351,8 +350,7 @@ class TestDatasourceProviderService:
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
@ -361,7 +359,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.set_default_datasource_provider("t1", make_id(), "bad-id")
@ -369,7 +367,7 @@ class TestDatasourceProviderService:
target = MagicMock(spec=DatasourceProvider)
target.provider = "provider"
target.plugin_id = "org/plug"
mock_db_session.query().first.return_value = target
mock_db_session.scalar.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
@ -428,13 +426,13 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"})
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_oauth_client("t1", make_id())
assert result == {"k": "dec"}
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
@ -444,7 +442,7 @@ class TestDatasourceProviderService:
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
"""Neither tenant nor system credentials → raises ValueError."""
mock_db_session.query().first.side_effect = [None, None]
mock_db_session.scalar.side_effect = [None, None]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
@ -457,15 +455,14 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 1 # conflict first, then auto-named
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
@ -475,8 +472,7 @@ class TestDatasourceProviderService:
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
mock_db_session.query().count.return_value = 0
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 0
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
@ -485,13 +481,13 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
self._redis.lock.assert_called()
@ -501,23 +497,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "extract_secret_variables", return_value=[]):
with pytest.raises(ValueError, match="not found"):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
mock_db_session.query().all.return_value = []
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
mock_db_session.scalars.return_value.all.return_value = []
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
@ -525,16 +519,14 @@ class TestDatasourceProviderService:
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
self._redis.lock.assert_called()
@ -545,13 +537,13 @@ class TestDatasourceProviderService:
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
"""explicit name supplied + conflict → raises ValueError immediately."""
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.return_value = 1
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
@ -561,7 +553,7 @@ class TestDatasourceProviderService:
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@ -571,7 +563,7 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@ -694,7 +686,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="not found"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
@ -704,8 +696,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
@ -717,8 +708,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
@ -733,8 +723,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "old_enc"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),

View File

@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
existing.disabled_by = "u"
session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = existing
session.query.return_value = query
session.scalar.return_value = existing
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat
# New session used after vectorization succeeds (record not found by id nor chunk_id).
session = MagicMock(name="session")
q1 = MagicMock()
q1.filter_by.return_value = q1
q1.first.side_effect = [None, None]
session.query.return_value = q1
session.scalar.side_effect = [None, None]
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes
# error_session should find record and commit status update
error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary
create_session_mock = MagicMock(return_value=_SessionContext(error_session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo
existing.enabled = False
session = MagicMock()
query = MagicMock()
query.filter.return_value = query
query.all.return_value = [existing]
session.query.return_value = query
session.scalars.return_value.all.return_value = [existing]
monkeypatch.setattr(
summary_module,
@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon
record = _summary_record()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch)
record = _summary_record(summary_content="")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch
record = _summary_record(summary_content="")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
existing = _summary_record(summary_content="old", node_id="old-node")
existing.id = "other-id"
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, existing] # miss by id, hit by chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, existing] # miss by id, hit by chunk_id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte
existing = _summary_record(summary_content="old", node_id="old-node")
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = existing # hit by id
session.query.return_value = q
session.scalar.return_value = existing # hit by id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
return None
error_session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary
create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
)
session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None] # miss by id and chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, None] # miss by id and chunk_id
error_session = MagicMock()
eq = MagicMock()
eq.filter_by.return_value = eq
eq.first.return_value = summary
error_session.query.return_value = eq
error_session.scalar.return_value = summary
create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
# Force the created record to be None so the "should not be None" guard triggers.
# Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause.
monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock()))
monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None))
with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"):
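# Illustrative aside on the patching trick above (an assumption about its intent; all
# names are placeholders): once the model class is a MagicMock, a real sqlalchemy
# select() would reject it as a selectable, so select() in the module under test is
# also replaced, making statement construction inert.
from unittest.mock import MagicMock

fake_select = MagicMock(return_value=MagicMock(name="inert_statement"))
stmt = fake_select(MagicMock(name="FakeModel"))  # no SQLAlchemy validation happens here
fake_select.assert_called_once()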
@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_
)
error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None] # not found by id, not found by chunk_id
error_session.query.return_value = q
error_session.scalar.side_effect = [None, None] # not found by id, not found by chunk_id
monkeypatch.setattr(
summary_module,
@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
seg2.id = "seg-2"
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg1, seg2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg1, seg2]
monkeypatch.setattr(
summary_module,
@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
seg = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
segment_ids=[seg.id],
only_parent_chunks=True,
)
query.filter.assert_called()
session.scalars.assert_called()
def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None:
@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
summary2 = _summary_record(summary_content="s", node_id=None)
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary1, summary2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary1, summary2]
monkeypatch.setattr(
summary_module,
@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary]
seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.return_value = segment
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query
session.query.side_effect = query_side_effect
session.scalars.return_value.all.return_value = [summary]
session.scalar.return_value = segment
monkeypatch.setattr(
summary_module,
@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect
good_segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary1, summary2, summary3]
session.scalars.return_value.all.return_value = [summary1, summary2, summary3]
session.scalar.side_effect = [bad_segment, good_segment, good_segment]
seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.side_effect = [bad_segment, good_segment, good_segment]
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query
session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
summary = _summary_record(summary_content="sum", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary]
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
session.flush.side_effect = RuntimeError("flush boom")
monkeypatch.setattr(
summary_module,
@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None:
record = _summary_record(summary_content="sum", node_id="n1")
session = MagicMock()
session.scalar.return_value = record
session.scalars.return_value.all.return_value = [record]
q1 = MagicMock()
q1.where.return_value = q1
q1.first.return_value = record
q2 = MagicMock()
q2.filter.return_value = q2
q2.all.return_value = [record]
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
# first call used by get_segment_summary, second by get_document_summaries
if not hasattr(query_side_effect, "_called"):
query_side_effect._called = True # type: ignore[attr-defined]
return q1
return q2
return MagicMock()
session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
record2 = _summary_record()
record2.chunk_id = "seg-2"
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [record1, record2]
session.query.return_value = q
session.scalars.return_value.all.return_value = [record1, record2]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = []
session.query.return_value = q
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk
def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
session.query.return_value = q
session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None:
def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None:
seg_row = SimpleNamespace(id="seg-1", document_id="doc-1")
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.all.return_value = [SimpleNamespace(id="seg-1")]
session.query.return_value = query
session.scalars.return_value.all.return_value = ["seg-1"] # get_document_summary_index_status returns IDs
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt
assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING"
# Multiple docs
query2 = MagicMock()
query2.where.return_value = query2
query2.all.return_value = [seg_row]
session2 = MagicMock()
session2.query.return_value = query2
session2.execute.return_value.all.return_value = [seg_row] # get_documents_summary_index_status uses execute
monkeypatch.setattr(
summary_module,
"session_factory",

View File

@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su
provider_id: TriggerProviderID,
) -> None:
# Arrange
query = MagicMock()
query.filter_by.return_value.order_by.return_value.all.return_value = []
mock_session.query.return_value = query
mock_session.scalars.return_value.all.return_value = []
# Act
result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id)
@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf
db_sub = SimpleNamespace(to_api_entity=lambda: api_sub)
usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2)
query_subs = MagicMock()
query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub]
query_usage = MagicMock()
query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row]
mock_session.query.side_effect = [query_subs, query_usage]
mock_session.scalars.return_value.all.return_value = [db_sub]
mock_session.execute.return_value.all.return_value = [usage_row]
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"})
@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(encrypted={"api_key": "enc"})
@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(encrypted={"p": "enc"})
@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
mock_session.query.return_value = query_count
mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
_mock_get_trigger_provider(mocker, provider_controller)
mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger")
@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists(
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, object()] # count=0, existing name conflict
_mock_get_trigger_provider(mocker, provider_controller)
# Act + Assert
@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = None
mock_session.query.return_value = query_sub
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts(
provider_id="langgenius/github/github",
credential_type=CredentialType.API_KEY.value,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, object()] # found sub, name conflict
_mock_get_trigger_provider(mocker, provider_controller)
# Act + Assert
@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
credential_expires_at=0,
expires_at=0,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, None] # found sub, no name conflict
_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"})
@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1")
@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties(
credentials={"token": "enc"},
properties={"project": "enc"},
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
prop_enc = _encrypter_mock(decrypted={"project": "plain"})
@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing(
mock_session: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri
credentials={"token": "enc"},
to_entity=lambda: SimpleNamespace(id="sub-1"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
mocker.patch(
@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized(
credentials={},
to_entity=lambda: SimpleNamespace(id="sub-2"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger")
mocker.patch(
@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials(
) -> None:
# Arrange
subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
# Act + Assert
with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"):
@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
credentials={"access_token": "enc"},
credential_expires_at=0,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cache = MagicMock()
cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"})
@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
subscription = SimpleNamespace(expires_at=200)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
# Act
result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100)
@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
credentials={"c": "enc"},
credential_type=CredentialType.API_KEY.value,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"c": "plain"})
prop_cache = MagicMock()
@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available(
) -> None:
# Arrange
tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"})
system_client = None
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = tenant_client
mock_session.query.return_value = query_tenant
mock_session.scalar.return_value = tenant_client
_mock_get_trigger_provider(mocker, provider_controller)
enc = _encrypter_mock(decrypted={"client_id": "plain"})
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.return_value = None # no tenant client; plugin not verified → early return
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False)
@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record(
provider_controller: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None
mock_session.scalar.return_value = object() if has_client else None
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
provider_controller: MagicMock,
) -> None:
# Arrange
query = MagicMock()
query.filter_by.return_value.first.return_value = None
mock_session.query.return_value = query
mock_session.scalar.return_value = None
_mock_get_trigger_provider(mocker, provider_controller)
fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={})
# Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient.
mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock()))
mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model)
# Act
@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
) -> None:
# Arrange
custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False)
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
mock_session.scalar.return_value = custom_client
_mock_get_trigger_provider(mocker, provider_controller)
cache = MagicMock()
enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"})
@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id)
@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values(
) -> None:
# Arrange
custom_client = SimpleNamespace(oauth_params={"client_id": "enc"})
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
mock_session.scalar.return_value = custom_client
_mock_get_trigger_provider(mocker, provider_controller)
enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"})
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
mock_session: MagicMock,
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.delete.return_value = 1
# Act
result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id)
@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean(
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None
mock_session.scalar.return_value = object() if exists else None
# Act
result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id)
@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1")
@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties(
credentials={"token": "enc"},
properties={"hook": "enc"},
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch(
"services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription",

View File

@ -969,8 +969,7 @@ class TestWorkflowService:
# 1. Workflow exists
# 2. No app is currently using it
# 3. Not published as a tool
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider
mock_session.scalar.side_effect = [mock_workflow, None, None] # workflow, no app using it, no tool provider
with patch("services.workflow_service.select") as mock_select:
mock_stmt = MagicMock()
@ -1045,8 +1044,7 @@ class TestWorkflowService:
mock_tool_provider = MagicMock()
mock_session = MagicMock()
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider
mock_session.scalar.side_effect = [mock_workflow, None, mock_tool_provider] # workflow, no app, tool provider
with patch("services.workflow_service.select") as mock_select:
mock_stmt = MagicMock()

View File

@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams:
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.execute.assert_called_once()
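# Illustrative sketch (an assumption about the service's new shape; the table name is
# a placeholder): a bulk delete that used to be mocked as
# session.query(...).filter_by(...).delete() would now be issued as an executable
# statement, which is why the assertion above moved to session.execute.
from unittest.mock import MagicMock
from sqlalchemy import column, delete, table

oauth_clients = table("illustrative_oauth_tenant_clients", column("tenant_id"))
session = MagicMock()
session.execute(delete(oauth_clients).where(oauth_clients.c.tenant_id == "tenant-1"))
session.execute.assert_called_once()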
class TestListBuiltinToolProviderTools:
@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_true_when_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
session.scalar.return_value = MagicMock()
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_false_when_missing(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_true_when_enabled(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
session.scalar.return_value = MagicMock(enabled=True)
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_false_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider:
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
mock_cache = MagicMock()
mock_enc.return_value = (MagicMock(), mock_cache)
@ -177,7 +177,7 @@ class TestSetDefaultProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@ -187,7 +187,7 @@ class TestSetDefaultProvider:
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target
session.scalar.return_value = target
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider:
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
mock_cred_instance = MagicMock()
mock_cred_instance.is_editable.return_value = True
@ -274,7 +274,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (mock_encrypter, MagicMock())
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
session.query.return_value.filter_by.return_value.first.return_value = user_client
session.scalar.return_value = user_client
result = BuiltinToolManageService.get_oauth_client("t", "google")
@ -297,10 +297,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (MagicMock(), MagicMock())
system_client = MagicMock(encrypted_oauth_params="enc")
session.query.return_value.filter_by.return_value.first.side_effect = [
None, # user client
system_client, # system client
]
session.scalar.side_effect = [None, system_client]
result = BuiltinToolManageService.get_oauth_client("t", "google")
@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams:
@patch(f"{MODULE}.db")
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
@ -391,7 +388,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.return_value.provider_name = "google"
mock_prov_id.return_value.organization = "langgenius"
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
session.scalar.return_value = None
result = BuiltinToolManageService.get_builtin_provider("google", "t")
@ -417,7 +414,7 @@ class TestGetBuiltinProvider:
return m
mock_prov_id.side_effect = prov_id_side_effect
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("google", "t")
@ -439,7 +436,7 @@ class TestGetBuiltinProvider:
mock_prov_id.side_effect = prov_id_side_effect
db_provider = MagicMock(provider="third-party/custom/custom-tool")
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
@ -452,7 +449,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.side_effect = Exception("parse error")
fallback = MagicMock()
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
session.scalar.return_value = fallback
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
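
The mock updates above assume the service now runs 2.0-style select() statements through Session.scalar() instead of legacy .query(...).filter_by(...).first() chains. A minimal sketch of that pattern, not part of the diff; Provider and get_provider are illustrative stand-ins, not names from this repository:

from unittest.mock import MagicMock

from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Provider(Base):
    # Stand-in for the real BuiltinToolProvider model; the columns are assumed.
    __tablename__ = "providers"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    name: Mapped[str] = mapped_column(String)


def get_provider(session, tenant_id: str, name: str):
    # One 2.0-style statement; Session.scalar() returns the first result or None,
    # so there is no .query(...).filter_by(...).first() chain left to mock.
    stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.name == name)
    return session.scalar(stmt)


# In tests like the ones above, the whole chain collapses to one mocked attribute:
session = MagicMock()
session.scalar.return_value = None  # simulates the "provider not found" branch
assert get_provider(session, "t", "p") is None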

View File

@ -82,8 +82,8 @@ def mock_db_session():
"""Mock session_factory.create_session() to return a session whose queries use shared test data.
Tests set session._shared_data = {"dataset": <Dataset>, "documents": [<Document>, ...]}
This fixture makes session.query(Dataset).first() return the shared dataset,
and session.query(Document).all()/first() return from the shared documents.
This fixture makes session.scalar(select(Dataset)...) return the shared dataset,
and session.scalars(select(Document)...).all() return the shared documents.
"""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
@ -92,93 +92,68 @@ def mock_db_session():
# Keep a pointer so repeated Document.first() calls iterate across provided docs
session._doc_first_idx = 0
def _query_side_effect(model):
q = MagicMock()
def _get_entity(stmt) -> type | None:
"""Extract the mapped entity class from a SQLAlchemy select statement."""
try:
descs = stmt.column_descriptions
if descs:
return descs[0].get("entity")
except (AttributeError, TypeError):
pass
return None
# Capture filters passed via where(...) so first()/all() can honor them.
q._filters = {}
def _extract_id_from_where(stmt) -> str | None:
"""Return the value bound to the 'id' column in the WHERE clause, if present."""
try:
where = stmt.whereclause
if where is None:
return None
# Both single-clause and AND-clause-list cases
clauses = list(getattr(where, "clauses", [where]))
for clause in clauses:
left = getattr(clause, "left", None)
right = getattr(clause, "right", None)
if left is not None and right is not None:
if getattr(left, "key", None) == "id":
return getattr(right, "value", None)
except Exception:
pass
return None
def _extract_filters(*conds, **kw):
# Support both SQLAlchemy expressions (BinaryExpression) and kwargs
# We only need the simple fields used by production code: id, dataset_id, and id.in_(...)
for cond in conds:
left = getattr(cond, "left", None)
right = getattr(cond, "right", None)
key = None
if left is not None:
key = getattr(left, "key", None) or getattr(left, "name", None)
if not key:
continue
# Right side might be a BindParameter with .value, or a raw value/sequence
val = getattr(right, "value", right)
q._filters[key] = val
# Also accept kwargs (e.g., where(id=...)) just in case
for k, v in kw.items():
q._filters[k] = v
def _where_side_effect(*conds, **kw):
_extract_filters(*conds, **kw)
return q
q.where.side_effect = _where_side_effect
# Dataset queries
if model.__name__ == "Dataset":
def _dataset_first():
ds = session._shared_data.get("dataset")
if not ds:
return None
if "id" in q._filters:
val = q._filters["id"]
if isinstance(val, (list, tuple, set)):
return ds if ds.id in val else None
return ds if ds.id == val else None
return ds
def _dataset_all():
ds = session._shared_data.get("dataset")
if not ds:
return []
first = _dataset_first()
return [first] if first else []
q.first.side_effect = _dataset_first
q.all.side_effect = _dataset_all
return q
# Document queries
if model.__name__ == "Document":
def _apply_doc_filters(docs):
result = list(docs)
for key in ("id", "dataset_id"):
if key in q._filters:
val = q._filters[key]
if isinstance(val, (list, tuple, set)):
result = [d for d in result if getattr(d, key, None) in val]
else:
result = [d for d in result if getattr(d, key, None) == val]
return result
def _docs_all():
def _scalar_side_effect(stmt):
entity = _get_entity(stmt)
if entity is not None:
if entity.__name__ == "Dataset":
return session._shared_data.get("dataset")
elif entity.__name__ == "Document":
docs = session._shared_data.get("documents", [])
return _apply_doc_filters(docs)
if not docs:
return None
# When the WHERE clause filters by id, return the matching document
queried_id = _extract_id_from_where(stmt)
if queried_id:
doc_map = {d.id: d for d in docs}
return doc_map.get(queried_id, docs[0])
return docs[0]
return None
def _docs_first():
docs = _docs_all()
return docs[0] if docs else None
def _scalars_side_effect(stmt):
entity = _get_entity(stmt)
result = MagicMock()
if entity is not None:
if entity.__name__ == "Document":
result.all.return_value = list(session._shared_data.get("documents", []))
elif entity.__name__ == "Dataset":
ds = session._shared_data.get("dataset")
result.all.return_value = [ds] if ds else []
else:
result.all.return_value = []
else:
result.all.return_value = []
return result
q.all.side_effect = _docs_all
q.first.side_effect = _docs_first
return q
# Default fallback
q.first.return_value = None
q.all.return_value = []
return q
session.query.side_effect = _query_side_effect
session.scalar.side_effect = _scalar_side_effect
session.scalars.side_effect = _scalars_side_effect
# Implement session.begin() context manager that commits on exit
session.commit = MagicMock()
@ -638,8 +613,6 @@ class TestProgressTracking:
wrapper = TaskWrapper(data=next_task_data)
mock_redis.rpop.return_value = wrapper.serialize()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -662,7 +635,6 @@ class TestProgressTracking:
"""
# Arrange
mock_redis.rpop.return_value = None # No more tasks
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -780,8 +752,7 @@ class TestErrorHandling:
If the dataset doesn't exist, the task should exit gracefully.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Arrange - dataset is not in _shared_data (None by default), so scalar() returns None
# Act
_document_indexing(dataset_id, document_ids)
@ -806,8 +777,6 @@ class TestErrorHandling:
# Set up rpop to return task once for concurrency check
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Make _document_indexing raise an error
with patch("tasks.document_indexing_task._document_indexing") as mock_indexing:
mock_indexing.side_effect = Exception("Processing failed")
@ -844,7 +813,7 @@ class TestErrorHandling:
# Mock rpop to return tasks one by one
mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -977,7 +946,7 @@ class TestAdvancedScenarios:
# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1070,7 +1039,7 @@ class TestAdvancedScenarios:
# Mock rpop to return tasks in FIFO order
mock_redis.rpop.side_effect = tasks + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1108,7 +1077,7 @@ class TestAdvancedScenarios:
"""
# Arrange
mock_redis.rpop.return_value = None # Empty queue
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
# Act
@ -1276,7 +1245,7 @@ class TestIntegration:
# First call returns task 2, second call returns None
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1433,7 +1402,7 @@ class TestPerformanceScenarios:
# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset does not exist."""
# Arrange
session = MagicMock()
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = None
session.query.side_effect = lambda model: dataset_query
session = MagicMock()
session.scalar.return_value = None # dataset not found
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow:
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.first.return_value = document
session = MagicMock()
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
def _scalar_se(stmt):
entity = stmt.column_descriptions[0].get("entity")
if entity is Dataset:
return dataset
return document
session.scalar.side_effect = _scalar_se
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow:
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_document_query
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(
all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status])
)
create_session_mock = MagicMock(
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow:
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset is missing after indexing."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.side_effect = [dataset, None]
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = None # dataset not found on second query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow:
_document_indexing("dataset-1", ["doc-1"])
# Assert
session3.query.assert_called()
session3.scalar.assert_called()
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation skipped when indexing_technique is not high_quality."""
@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="economy",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test summary generation is skipped when indexing is paused."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test generic indexing runner exception is handled."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",

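The rewritten fixture above keys its fake results off the select statement itself rather than the model passed to session.query(), using Select.column_descriptions to decide whether a Dataset or a Document is being looked up. A self-contained sketch of that routing trick, with throwaway Dataset/Document stand-ins instead of the project's models:

from unittest.mock import MagicMock

from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    # Throwaway stand-in; the real Dataset/Document models live in the project.
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String, primary_key=True)


class Document(Base):
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)


session = MagicMock()
shared = {"dataset": Dataset(id="ds-1"), "documents": [Document(id="doc-1")]}


def _scalar(stmt):
    # Select.column_descriptions[0]["entity"] is the mapped class the statement
    # targets, which lets the fake session route Dataset vs. Document lookups.
    entity = stmt.column_descriptions[0]["entity"]
    if entity is Dataset:
        return shared["dataset"]
    return shared["documents"][0]


session.scalar.side_effect = _scalar
assert session.scalar(select(Dataset)).id == "ds-1"
assert session.scalar(select(Document).where(Document.id == "doc-1")).id == "doc-1"
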
View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch
import pytest
from libs.archive_storage import ArchiveStorageNotConfiguredError
from models.workflow import WorkflowArchiveLog
from tasks.remove_app_and_related_data_task import (
_delete_app_workflow_archive_logs,
_delete_archived_workflow_run_files,
@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs:
assert params == {"tenant_id": tenant_id, "app_id": app_id}
assert name == "workflow archive log"
mock_query = MagicMock()
mock_delete_query = MagicMock()
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
mock_session = MagicMock()
delete_func(mock_db.session, "log-1")
delete_func(mock_session, "log-1")
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()
mock_delete_query.delete.assert_called_once_with(synchronize_session=False)
mock_session.execute.assert_called_once()
class TestDeleteArchivedWorkflowRunFiles:
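
The updated assertion suggests the per-row delete callback now receives a session and issues a delete() statement via session.execute(), rather than a db.session.query(...).where(...).delete() chain. A small sketch under that assumption; ArchiveLog and delete_log are illustrative names, not taken from the diff:

from unittest.mock import MagicMock

from sqlalchemy import String, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ArchiveLog(Base):
    # Illustrative stand-in for WorkflowArchiveLog; the column names are assumed.
    __tablename__ = "archive_logs"
    id: Mapped[str] = mapped_column(String, primary_key=True)


def delete_log(session, log_id: str) -> None:
    # 2.0-style bulk delete: build a delete() statement and hand it to the session.
    session.execute(delete(ArchiveLog).where(ArchiveLog.id == log_id))


session = MagicMock()
delete_log(session, "log-1")
session.execute.assert_called_once()  # mirrors the assertion in the updated test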

api/uv.lock generated (6 lines changed)
View File

@ -4763,11 +4763,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.9.2"
version = "6.10.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/9f/ca96abf18683ca12602065e4ed2bec9050b672c87d317f1079abc7b6d993/pypdf-6.10.0.tar.gz", hash = "sha256:4c5a48ba258c37024ec2505f7e8fd858525f5502784a2e1c8d415604af29f6ef", size = 5314833, upload-time = "2026-04-10T09:34:57.102Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" },
{ url = "https://files.pythonhosted.org/packages/55/f2/7ebe366f633f30a6ad105f650f44f24f98cb1335c4157d21ae47138b3482/pypdf-6.10.0-py3-none-any.whl", hash = "sha256:90005e959e1596c6e6c84c8b0ad383285b3e17011751cedd17f2ce8fcdfc86de", size = 334459, upload-time = "2026-04-10T09:34:54.966Z" },
]
[[package]]

View File

@ -27,7 +27,7 @@ describe('NumberInputField', () => {
it('should update value when users click increment', () => {
render(<NumberInputField label="Count" />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.increment' }))
fireEvent.click(screen.getByRole('button', { name: 'Increment value' }))
expect(mockField.handleChange).toHaveBeenCalledWith(3)
})

View File

@ -172,13 +172,13 @@ describe('NumberField wrapper', () => {
// Increment and decrement buttons should preserve accessible naming, icon fallbacks, and spacing variants.
describe('Control buttons', () => {
it('should provide localized aria labels and default icons when labels are not provided', () => {
it('should provide English fallback aria labels and default icons when labels are not provided', () => {
renderNumberField({
controlsProps: {},
})
const increment = screen.getByRole('button', { name: 'common.operation.increment' })
const decrement = screen.getByRole('button', { name: 'common.operation.decrement' })
const increment = screen.getByRole('button', { name: 'Increment value' })
const decrement = screen.getByRole('button', { name: 'Decrement value' })
expect(increment.querySelector('.i-ri-arrow-up-s-line')).toBeInTheDocument()
expect(decrement.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
@ -217,11 +217,11 @@ describe('NumberField wrapper', () => {
},
})
expect(screen.getByRole('button', { name: 'common.operation.increment' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.decrement' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Increment value' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Decrement value' })).toBeInTheDocument()
})
it('should rely on aria-labelledby when provided instead of injecting a translated aria-label', () => {
it('should rely on aria-labelledby when provided instead of injecting a fallback aria-label', () => {
render(
<>
<span id="increment-label">Increment from label</span>

View File

@ -4,7 +4,6 @@ import type { VariantProps } from 'class-variance-authority'
import { NumberField as BaseNumberField } from '@base-ui/react/number-field'
import { cva } from 'class-variance-authority'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
export const NumberField = BaseNumberField.Root
@ -188,18 +187,19 @@ type NumberFieldButtonVariantProps = Omit<
export type NumberFieldButtonProps = React.ComponentPropsWithoutRef<typeof BaseNumberField.Increment> & NumberFieldButtonVariantProps
const incrementAriaLabel = 'Increment value'
const decrementAriaLabel = 'Decrement value'
export function NumberFieldIncrement({
className,
children,
size = 'regular',
...props
}: NumberFieldButtonProps) {
const { t } = useTranslation()
return (
<BaseNumberField.Increment
{...props}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : t('operation.increment', { ns: 'common' }))}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : incrementAriaLabel)}
className={cn(numberFieldControlButtonVariants({ size, direction: 'increment' }), className)}
>
{children ?? <span aria-hidden="true" className="i-ri-arrow-up-s-line size-3" />}
@ -213,12 +213,10 @@ export function NumberFieldDecrement({
size = 'regular',
...props
}: NumberFieldButtonProps) {
const { t } = useTranslation()
return (
<BaseNumberField.Decrement
{...props}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : t('operation.decrement', { ns: 'common' }))}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : decrementAriaLabel)}
className={cn(numberFieldControlButtonVariants({ size, direction: 'decrement' }), className)}
>
{children ?? <span aria-hidden="true" className="i-ri-arrow-down-s-line size-3" />}

View File

@ -31,13 +31,13 @@ describe('base/ui/toast', () => {
expect(await screen.findByText('Saved')).toBeInTheDocument()
expect(screen.getByText('Your changes are available now.')).toBeInTheDocument()
const viewport = screen.getByRole('region', { name: 'common.toast.notifications' })
const viewport = screen.getByRole('region', { name: 'Notifications' })
expect(viewport).toHaveAttribute('aria-live', 'polite')
expect(viewport).toHaveClass('z-1101')
expect(viewport.firstElementChild).toHaveClass('top-4')
expect(screen.getByRole('dialog')).not.toHaveClass('outline-hidden')
expect(document.body.querySelector('[aria-hidden="true"].i-ri-checkbox-circle-fill')).toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="common.toast.close"][aria-hidden="true"]')).toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="Close notification"][aria-hidden="true"]')).toBeInTheDocument()
})
// Collapsed stacks should keep multiple toast roots mounted for smooth stack animation.
@ -57,12 +57,12 @@ describe('base/ui/toast', () => {
expect(await screen.findByText('Third toast')).toBeInTheDocument()
expect(screen.getAllByRole('dialog')).toHaveLength(3)
expect(document.body.querySelectorAll('button[aria-label="common.toast.close"][aria-hidden="true"]')).toHaveLength(3)
expect(document.body.querySelectorAll('button[aria-label="Close notification"][aria-hidden="true"]')).toHaveLength(3)
fireEvent.mouseEnter(screen.getByRole('region', { name: 'common.toast.notifications' }))
fireEvent.mouseEnter(screen.getByRole('region', { name: 'Notifications' }))
await waitFor(() => {
expect(document.body.querySelector('button[aria-label="common.toast.close"][aria-hidden="true"]')).not.toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="Close notification"][aria-hidden="true"]')).not.toBeInTheDocument()
})
})
@ -126,9 +126,9 @@ describe('base/ui/toast', () => {
})
})
fireEvent.mouseEnter(screen.getByRole('region', { name: 'common.toast.notifications' }))
fireEvent.mouseEnter(screen.getByRole('region', { name: 'Notifications' }))
const dismissButton = await screen.findByRole('button', { name: 'common.toast.close' })
const dismissButton = await screen.findByRole('button', { name: 'Close notification' })
act(() => {
dismissButton.click()

View File

@ -7,7 +7,6 @@ import type {
} from '@base-ui/react/toast'
import type { ReactNode } from 'react'
import { Toast as BaseToast } from '@base-ui/react/toast'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
type ToastData = Record<string, never>
@ -35,6 +34,9 @@ const TOAST_TONE_STYLES = {
},
} satisfies Record<string, ToastToneStyle>
const toastCloseLabel = 'Close notification'
const toastViewportLabel = 'Notifications'
type ToastType = keyof typeof TOAST_TONE_STYLES
type ToastAddOptions = Omit<ToastManagerAddOptions<ToastData>, 'data' | 'positionerProps' | 'type'> & {
@ -145,7 +147,6 @@ function ToastCard({
}: {
toast: ToastObject<ToastData>
}) {
const { t } = useTranslation('common')
const toastType = getToastType(toastItem.type)
return (
@ -200,7 +201,7 @@ function ToastCard({
</div>
<div className="flex shrink-0 items-center justify-center rounded-md p-0.5">
<BaseToast.Close
aria-label={t('toast.close')}
aria-label={toastCloseLabel}
className={cn(
'flex h-5 w-5 items-center justify-center rounded-md hover:bg-state-base-hover focus-visible:bg-state-base-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover focus-visible:outline-hidden disabled:cursor-not-allowed disabled:opacity-50',
)}
@ -215,12 +216,11 @@ function ToastCard({
}
function ToastViewport() {
const { t } = useTranslation('common')
const { toasts } = BaseToast.useToastManager<ToastData>()
return (
<BaseToast.Viewport
aria-label={t('toast.notifications')}
aria-label={toastViewportLabel}
className={cn(
// During overlay migration, toast must stay above legacy highPriority modals (z-[1100]).
'inset-0 group/toast-viewport pointer-events-none fixed z-1101 overflow-visible',

View File

@ -0,0 +1,46 @@
import { render } from '@testing-library/react'
import { API_PREFIX } from '@/config'
import BlockIcon, { VarBlockIcon } from '../block-icon'
import { BlockEnum } from '../types'
describe('BlockIcon', () => {
it('renders the default workflow icon container for regular nodes', () => {
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)
const iconContainer = container.firstElementChild
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
})
it('normalizes protected plugin icon urls for tool-like nodes', () => {
const { container } = render(
<BlockIcon
type={BlockEnum.Tool}
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
/>,
)
const iconContainer = container.firstElementChild as HTMLElement
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement
expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
expect(backgroundIcon.style.backgroundImage).toContain(
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
)
})
})
describe('VarBlockIcon', () => {
it('renders the compact icon variant without the default container wrapper', () => {
const { container } = render(
<VarBlockIcon
type={BlockEnum.Answer}
className="custom-var-icon"
/>,
)
expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
})
})

View File

@ -0,0 +1,39 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { WorkflowContextProvider } from '../context'
import { useStore, useWorkflowStore } from '../store'
const StoreConsumer = () => {
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
const store = useWorkflowStore()
return (
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
{showSingleRunPanel ? 'open' : 'closed'}
</button>
)
}
describe('WorkflowContextProvider', () => {
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
const user = userEvent.setup()
const { rerender } = render(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'closed' }))
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
rerender(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
})
})

View File

@ -0,0 +1,67 @@
import type { Edge, Node } from '../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import WorkflowWithDefaultContext from '../index'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore } from '../workflow-history-store'
const nodes: Node[] = [
{
id: 'node-start',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
},
},
]
const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-start',
target: 'node-end',
sourceHandle: null,
targetHandle: null,
type: 'custom',
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]
const ContextConsumer = () => {
const { store, shortcutsEnabled } = useWorkflowHistoryStore()
const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
const reactFlowStore = useStoreApi()
return (
<div>
{`history:${store.getState().nodes.length}`}
{` shortcuts:${String(shortcutsEnabled)}`}
{` datasets:${datasetCount}`}
{` reactflow:${String(!!reactFlowStore)}`}
</div>
)
}
describe('WorkflowWithDefaultContext', () => {
it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
render(
<WorkflowWithDefaultContext
nodes={nodes}
edges={edges}
>
<ContextConsumer />
</WorkflowWithDefaultContext>,
)
expect(
screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'),
).toBeInTheDocument()
})
})

View File

@ -0,0 +1,51 @@
import { render, screen } from '@testing-library/react'
import ShortcutsName from '../shortcuts-name'
describe('ShortcutsName', () => {
const originalNavigator = globalThis.navigator
afterEach(() => {
Object.defineProperty(globalThis, 'navigator', {
value: originalNavigator,
writable: true,
configurable: true,
})
})
it('renders mac-friendly key labels and style variants', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Macintosh' },
writable: true,
configurable: true,
})
const { container } = render(
<ShortcutsName
keys={['ctrl', 'shift', 's']}
bgColor="white"
textColor="secondary"
/>,
)
expect(screen.getByText('⌘')).toBeInTheDocument()
expect(screen.getByText('⇧')).toBeInTheDocument()
expect(screen.getByText('s')).toBeInTheDocument()
expect(container.querySelector('.system-kbd')).toHaveClass(
'bg-components-kbd-bg-white',
'text-text-tertiary',
)
})
it('keeps raw key names on non-mac systems', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Windows NT' },
writable: true,
configurable: true,
})
render(<ShortcutsName keys={['ctrl', 'alt']} />)
expect(screen.getByText('ctrl')).toBeInTheDocument()
expect(screen.getByText('alt')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,97 @@
import type { Edge, Node } from '../types'
import type { WorkflowHistoryState } from '../workflow-history-store'
import { render, renderHook, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'
const nodes: Node[] = [
{
id: 'node-1',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
selected: true,
},
selected: true,
},
]
const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-1',
target: 'node-2',
sourceHandle: null,
targetHandle: null,
type: 'custom',
selected: true,
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]
const HistoryConsumer = () => {
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()
return (
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
</button>
)
}
describe('WorkflowHistoryProvider', () => {
it('provides workflow history state and shortcut toggles', async () => {
const user = userEvent.setup()
render(
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
<HistoryConsumer />
</WorkflowHistoryProvider>,
)
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
})
it('sanitizes selected flags when history state is replaced through the exposed store API', () => {
const wrapper = ({ children }: { children: React.ReactNode }) => (
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
{children}
</WorkflowHistoryProvider>
)
const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
const nextState: WorkflowHistoryState = {
workflowHistoryEvent: undefined,
workflowHistoryEventMeta: undefined,
nodes,
edges,
}
result.current.store.setState(nextState)
expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
expect(result.current.store.getState().edges[0].selected).toBe(false)
})
it('throws when consumed outside the provider', () => {
expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
)
})
})

View File

@ -0,0 +1,140 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AllTools from '../all-tools'
import { createGlobalPublicStoreState, createToolProvider } from './factories'
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
useMarketplacePlugins: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
vi.mock('@/utils/var', async importOriginal => ({
...(await importOriginal<typeof import('@/utils/var')>()),
getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))
const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
const createMarketplacePluginsMock = () => ({
plugins: [],
total: 0,
resetPlugins: vi.fn(),
queryPlugins: vi.fn(),
queryPluginsWithDebounced: vi.fn(),
cancelQueryPluginsWithDebounced: vi.fn(),
isLoading: false,
isFetchingNextPage: false,
hasNextPage: false,
fetchNextPage: vi.fn(),
page: 0,
})
describe('AllTools', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
})
it('filters tools by the active tab', async () => {
const user = userEvent.setup()
render(
<AllTools
searchText=""
tags={[]}
onSelect={vi.fn()}
buildInTools={[createToolProvider({
id: 'provider-built-in',
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
})]}
customTools={[createToolProvider({
id: 'provider-custom',
type: 'custom',
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
})]}
workflowTools={[]}
mcpTools={[]}
/>,
)
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
await user.click(screen.getByText('workflow.tabs.customTool'))
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
})
it('filters the rendered tools by the search text', () => {
render(
<AllTools
searchText="report"
tags={[]}
onSelect={vi.fn()}
buildInTools={[
createToolProvider({
id: 'provider-report',
label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
}),
createToolProvider({
id: 'provider-other',
label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
}),
]}
customTools={[]}
workflowTools={[]}
mcpTools={[]}
/>,
)
expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
})
it('shows the empty state when no tool matches the current filter', async () => {
render(
<AllTools
searchText="missing"
tags={[]}
onSelect={vi.fn()}
buildInTools={[]}
customTools={[]}
workflowTools={[]}
mcpTools={[]}
/>,
)
await waitFor(() => {
expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
})
})
})

View File

@ -1,51 +1,110 @@
import type { NodeDefault } from '../../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import userEvent from '@testing-library/user-event'
import { AppTypeEnum } from '@/types/app'
import { BlockEnum } from '../../types'
import Blocks from '../blocks'
import { BlockClassificationEnum } from '../types'
const mockGetNodes = vi.fn(() => [])
const mockAppType = vi.hoisted<{ current?: string }>(() => ({
current: 'workflow',
const runtimeState = vi.hoisted(() => ({
appType: 'workflow' as string | undefined,
nodes: [] as Array<{ data: { type?: BlockEnum } }>,
}))
vi.mock('reactflow', () => ({
useStoreApi: vi.fn(),
useStoreApi: () => ({
getState: () => ({
getNodes: () => runtimeState.nodes,
}),
}),
}))
vi.mock('@/app/components/app/store', () => ({
useStore: (selector: (state: { appDetail: { type?: string } }) => unknown) => selector({
appDetail: {
type: mockAppType.current,
type: runtimeState.appType,
},
}),
}))
const mockUseStoreApi = vi.mocked(useStoreApi)
const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({
metaData: {
classification,
sort: 0,
type,
title,
author: 'Dify',
description: `${title} description`,
},
defaultValue: {},
checkValid: () => ({ isValid: true }),
})
describe('Blocks', () => {
beforeEach(() => {
vi.clearAllMocks()
mockAppType.current = AppTypeEnum.WORKFLOW
mockUseStoreApi.mockReturnValue({
getState: () => ({
getNodes: mockGetNodes,
}),
} as unknown as ReturnType<typeof useStoreApi>)
runtimeState.appType = AppTypeEnum.WORKFLOW
runtimeState.nodes = []
})
it('should hide human input in evaluation workflows', () => {
mockAppType.current = AppTypeEnum.EVALUATION
it('should hide human input blocks when the app is an evaluation workflow', () => {
runtimeState.appType = AppTypeEnum.EVALUATION
render(
<Blocks
searchText=""
onSelect={vi.fn()}
availableBlocksTypes={[BlockEnum.HumanInput, BlockEnum.LLM]}
blocks={[
createBlock(BlockEnum.HumanInput, 'Human Input'),
createBlock(BlockEnum.LLM, 'LLM'),
]}
/>,
)
expect(screen.queryByText('workflow.blocks.human-input')).not.toBeInTheDocument()
expect(screen.getByText('workflow.blocks.llm')).toBeInTheDocument()
expect(screen.queryByText('Human Input')).not.toBeInTheDocument()
expect(screen.getByText('LLM')).toBeInTheDocument()
})
it('should render grouped blocks, filter duplicate knowledge-base nodes, and select a block', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]
render(
<Blocks
searchText=""
onSelect={onSelect}
availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
blocks={[
createBlock(BlockEnum.LLM, 'LLM'),
createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
]}
/>,
)
expect(screen.getByText('LLM')).toBeInTheDocument()
expect(screen.getByText('Exit Loop')).toBeInTheDocument()
expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()
await user.click(screen.getByText('LLM'))
expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
})
it('should show the empty state when no block matches the search text', () => {
render(
<Blocks
searchText="missing"
onSelect={vi.fn()}
availableBlocksTypes={[BlockEnum.LLM]}
blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
/>,
)
expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,101 @@
import type { ToolWithProvider } from '../../types'
import type { Plugin } from '@/app/components/plugins/types'
import type { Tool } from '@/app/components/tools/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { CollectionType } from '@/app/components/tools/types'
import { defaultSystemFeatures } from '@/types/feature'
export const createTool = (
name: string,
label: string,
description = `${label} description`,
): Tool => ({
name,
author: 'author',
label: {
en_US: label,
zh_Hans: label,
},
description: {
en_US: description,
zh_Hans: description,
},
parameters: [],
labels: [],
output_schema: {},
})
export const createToolProvider = (
overrides: Partial<ToolWithProvider> = {},
): ToolWithProvider => ({
id: 'provider-1',
name: 'provider-one',
author: 'Provider Author',
description: {
en_US: 'Provider description',
zh_Hans: 'Provider description',
},
icon: 'icon',
icon_dark: 'icon-dark',
label: {
en_US: 'Provider One',
zh_Hans: 'Provider One',
},
type: CollectionType.builtIn,
team_credentials: {},
is_team_authorization: false,
allow_delete: false,
labels: [],
plugin_id: 'plugin-1',
tools: [createTool('tool-a', 'Tool A')],
meta: { version: '1.0.0' } as ToolWithProvider['meta'],
plugin_unique_identifier: 'plugin-1@1.0.0',
...overrides,
})
export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
type: 'plugin',
org: 'org',
author: 'author',
name: 'Plugin One',
plugin_id: 'plugin-1',
version: '1.0.0',
latest_version: '1.0.0',
latest_package_identifier: 'plugin-1@1.0.0',
icon: 'icon',
verified: true,
label: {
en_US: 'Plugin One',
zh_Hans: 'Plugin One',
},
brief: {
en_US: 'Plugin description',
zh_Hans: 'Plugin description',
},
description: {
en_US: 'Plugin description',
zh_Hans: 'Plugin description',
},
introduction: 'Plugin introduction',
repository: 'https://example.com/plugin',
category: PluginCategoryEnum.tool,
tags: [],
badges: [],
install_count: 0,
endpoint: {
settings: [],
},
verification: {
authorized_category: 'community',
},
from: 'github',
...overrides,
})
export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
systemFeatures: {
...defaultSystemFeatures,
enable_marketplace: enableMarketplace,
},
setSystemFeatures: vi.fn(),
})

View File

@ -0,0 +1,101 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import FeaturedTools from '../featured-tools'
import { createPlugin, createToolProvider } from './factories'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
vi.mock('@/utils/var', async importOriginal => ({
...(await importOriginal<typeof import('@/utils/var')>()),
getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
describe('FeaturedTools', () => {
beforeEach(() => {
vi.clearAllMocks()
localStorage.clear()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('shows more featured tools when the list exceeds the initial quota', async () => {
const user = userEvent.setup()
const plugins = Array.from({ length: 6 }, (_, index) =>
createPlugin({
plugin_id: `plugin-${index + 1}`,
latest_package_identifier: `plugin-${index + 1}@1.0.0`,
label: { en_US: `Plugin ${index + 1}`, zh_Hans: `Plugin ${index + 1}` },
}))
const providers = plugins.map((plugin, index) =>
createToolProvider({
id: `provider-${index + 1}`,
plugin_id: plugin.plugin_id,
label: { en_US: `Provider ${index + 1}`, zh_Hans: `Provider ${index + 1}` },
}),
)
const providerMap = new Map(providers.map(provider => [provider.plugin_id!, provider]))
render(
<FeaturedTools
plugins={plugins}
providerMap={providerMap}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('Provider 1')).toBeInTheDocument()
expect(screen.queryByText('Provider 6')).not.toBeInTheDocument()
await user.click(screen.getByText('workflow.tabs.showMoreFeatured'))
expect(screen.getByText('Provider 6')).toBeInTheDocument()
})
it('honors the persisted collapsed state', () => {
localStorage.setItem('workflow_tools_featured_collapsed', 'true')
render(
<FeaturedTools
plugins={[createPlugin()]}
providerMap={new Map([[
'plugin-1',
createToolProvider(),
]])}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('workflow.tabs.featuredTools')).toBeInTheDocument()
expect(screen.queryByText('Provider One')).not.toBeInTheDocument()
})
it('shows the marketplace empty state when no featured tools are available', () => {
render(
<FeaturedTools
plugins={[]}
providerMap={new Map()}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('workflow.tabs.noFeaturedPlugins')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,52 @@
import { act, renderHook } from '@testing-library/react'
import { useTabs, useToolTabs } from '../hooks'
import { TabsEnum, ToolTypeEnum } from '../types'
describe('block-selector hooks', () => {
it('falls back to the first valid tab when the preferred start tab is disabled', () => {
const { result } = renderHook(() => useTabs({
noStart: false,
hasUserInputNode: true,
defaultActiveTab: TabsEnum.Start,
}))
expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBe(true)
expect(result.current.activeTab).toBe(TabsEnum.Blocks)
})
it('keeps the start tab enabled when forcing it on and resets to a valid tab after disabling blocks', () => {
const props: Parameters<typeof useTabs>[0] = {
noBlocks: false,
noStart: false,
hasUserInputNode: true,
forceEnableStartTab: true,
}
const { result, rerender } = renderHook(nextProps => useTabs(nextProps), {
initialProps: props,
})
expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBeFalsy()
act(() => {
result.current.setActiveTab(TabsEnum.Blocks)
})
rerender({
...props,
noBlocks: true,
noSources: true,
noTools: true,
})
expect(result.current.activeTab).toBe(TabsEnum.Start)
})
it('returns the MCP tab only when it is not hidden', () => {
const { result: visible } = renderHook(() => useToolTabs())
const { result: hidden } = renderHook(() => useToolTabs(true))
expect(visible.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(true)
expect(hidden.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(false)
})
})

View File

@ -0,0 +1,90 @@
import type { NodeDefault, ToolWithProvider } from '../../types'
import { screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { BlockEnum } from '../../types'
import NodeSelectorWrapper from '../index'
import { BlockClassificationEnum } from '../types'
vi.mock('reactflow', async () =>
(await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock())
vi.mock('@/service/use-plugins', () => ({
useFeaturedToolsRecommendations: () => ({
plugins: [],
isLoading: false,
}),
}))
vi.mock('@/service/use-tools', () => ({
useAllBuiltInTools: () => ({ data: [] }),
useAllCustomTools: () => ({ data: [] }),
useAllWorkflowTools: () => ({ data: [] }),
useAllMCPTools: () => ({ data: [] }),
useInvalidateAllBuiltInTools: () => vi.fn(),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({
systemFeatures: { enable_marketplace: false },
}),
}))
const createBlock = (type: BlockEnum, title: string): NodeDefault => ({
metaData: {
type,
title,
sort: 0,
classification: BlockClassificationEnum.Default,
author: 'Dify',
description: `${title} description`,
},
defaultValue: {},
checkValid: () => ({ isValid: true }),
})
const dataSource: ToolWithProvider = {
id: 'datasource-1',
name: 'datasource',
author: 'Dify',
description: { en_US: 'Data source', zh_Hans: '数据源' },
icon: 'icon',
label: { en_US: 'Data Source', zh_Hans: 'Data Source' },
type: 'datasource' as ToolWithProvider['type'],
team_credentials: {},
is_team_authorization: false,
allow_delete: false,
labels: [],
tools: [],
meta: { version: '1.0.0' } as ToolWithProvider['meta'],
}
describe('NodeSelectorWrapper', () => {
it('filters hidden block types from hooks store and forwards data sources', async () => {
renderWorkflowComponent(
<NodeSelectorWrapper
open
onSelect={vi.fn()}
availableBlocksTypes={[BlockEnum.Code]}
/>,
{
hooksStoreProps: {
availableNodesMetaData: {
nodes: [
createBlock(BlockEnum.Start, 'Start'),
createBlock(BlockEnum.Tool, 'Tool'),
createBlock(BlockEnum.Code, 'Code'),
createBlock(BlockEnum.DataSource, 'Data Source'),
],
},
},
initialStoreState: {
dataSourceList: [dataSource],
},
},
)
expect(await screen.findByText('Code')).toBeInTheDocument()
expect(screen.queryByText('Start')).not.toBeInTheDocument()
expect(screen.queryByText('Tool')).not.toBeInTheDocument()
})
})

View File

@ -0,0 +1,95 @@
import type { NodeDefault } from '../../types'
import { screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { BlockEnum } from '../../types'
import NodeSelector from '../main'
import { BlockClassificationEnum } from '../types'
vi.mock('reactflow', () => ({
useStoreApi: () => ({
getState: () => ({
getNodes: () => [],
}),
}),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({
systemFeatures: { enable_marketplace: false },
}),
}))
vi.mock('@/service/use-plugins', () => ({
useFeaturedToolsRecommendations: () => ({
plugins: [],
isLoading: false,
}),
}))
vi.mock('@/service/use-tools', () => ({
useAllBuiltInTools: () => ({ data: [] }),
useAllCustomTools: () => ({ data: [] }),
useAllWorkflowTools: () => ({ data: [] }),
useAllMCPTools: () => ({ data: [] }),
useInvalidateAllBuiltInTools: () => vi.fn(),
}))
const createBlock = (type: BlockEnum, title: string): NodeDefault => ({
metaData: {
classification: BlockClassificationEnum.Default,
sort: 0,
type,
title,
author: 'Dify',
description: `${title} description`,
},
defaultValue: {},
checkValid: () => ({ isValid: true }),
})
describe('NodeSelector', () => {
it('opens with the real blocks tab, filters by search, selects a block, and clears search after close', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
renderWorkflowComponent(
<NodeSelector
onSelect={onSelect}
blocks={[
createBlock(BlockEnum.LLM, 'LLM'),
createBlock(BlockEnum.End, 'End'),
]}
availableBlocksTypes={[BlockEnum.LLM, BlockEnum.End]}
trigger={open => (
<button type="button">
{open ? 'selector-open' : 'selector-closed'}
</button>
)}
/>,
)
await user.click(screen.getByRole('button', { name: 'selector-closed' }))
const searchInput = screen.getByPlaceholderText('workflow.tabs.searchBlock')
expect(screen.getByText('LLM')).toBeInTheDocument()
expect(screen.getByText('End')).toBeInTheDocument()
await user.type(searchInput, 'LLM')
expect(screen.getByText('LLM')).toBeInTheDocument()
expect(screen.queryByText('End')).not.toBeInTheDocument()
await user.click(screen.getByText('LLM'))
expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM, undefined)
await waitFor(() => {
expect(screen.queryByPlaceholderText('workflow.tabs.searchBlock')).not.toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: 'selector-closed' }))
const reopenedInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') as HTMLInputElement
expect(reopenedInput.value).toBe('')
expect(screen.getByText('End')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,95 @@
import { render, screen } from '@testing-library/react'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import Tools from '../tools'
import { ViewType } from '../view-type-select'
import { createToolProvider } from './factories'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
describe('Tools', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('shows the empty state when there are no tools and no search text', () => {
render(
<Tools
tools={[]}
onSelect={vi.fn()}
viewType={ViewType.flat}
hasSearchText={false}
/>,
)
expect(screen.getByText('No tools available')).toBeInTheDocument()
})
it('renders tree groups for built-in and custom providers', () => {
render(
<Tools
tools={[
createToolProvider({
id: 'built-in-provider',
author: 'Built In',
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
}),
createToolProvider({
id: 'custom-provider',
type: CollectionType.custom,
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
}),
]}
onSelect={vi.fn()}
viewType={ViewType.tree}
hasSearchText={false}
/>,
)
expect(screen.getByText('Built In')).toBeInTheDocument()
expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument()
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
})
it('shows the alphabetical index in flat view when enough tools are present', () => {
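// Eleven providers labelled A through K, enough for the flat view to render the alphabetical index bar.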
const { container } = render(
<Tools
tools={Array.from({ length: 11 }, (_, index) =>
createToolProvider({
id: `provider-${index}`,
label: {
en_US: `${String.fromCharCode(65 + index)} Provider`,
zh_Hans: `${String.fromCharCode(65 + index)} Provider`,
},
}))}
onSelect={vi.fn()}
viewType={ViewType.flat}
hasSearchText={false}
/>,
)
expect(container.querySelector('.index-bar')).toBeInTheDocument()
expect(screen.getByText('A Provider')).toBeInTheDocument()
expect(screen.getByText('K Provider')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,99 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { trackEvent } from '@/app/components/base/amplitude'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { BlockEnum } from '../../../types'
import { createTool, createToolProvider } from '../../__tests__/factories'
import { ViewType } from '../../view-type-select'
import Tool from '../tool'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
const mockTrackEvent = vi.mocked(trackEvent)
describe('Tool', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('expands a provider and selects an action item', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
render(
<Tool
payload={createToolProvider({
tools: [
createTool('tool-a', 'Tool A'),
createTool('tool-b', 'Tool B'),
],
})}
viewType={ViewType.flat}
hasSearchText={false}
onSelect={onSelect}
/>,
)
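// Expanding the provider reveals its tools; selecting one should emit the payload and track the event.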
await user.click(screen.getByText('Provider One'))
await user.click(screen.getByText('Tool B'))
expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({
provider_id: 'provider-1',
provider_name: 'provider-one',
tool_name: 'tool-b',
title: 'Tool B',
}))
expect(mockTrackEvent).toHaveBeenCalledWith('tool_selected', {
tool_name: 'tool-b',
plugin_id: 'plugin-1',
})
})
it('selects workflow tools directly without expanding the provider', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
render(
<Tool
payload={createToolProvider({
type: CollectionType.workflow,
tools: [createTool('workflow-tool', 'Workflow Tool')],
})}
viewType={ViewType.flat}
hasSearchText={false}
onSelect={onSelect}
/>,
)
await user.click(screen.getByText('Workflow Tool'))
expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({
provider_type: CollectionType.workflow,
tool_name: 'workflow-tool',
tool_label: 'Workflow Tool',
}))
})
})

View File

@ -0,0 +1,66 @@
import { render, screen } from '@testing-library/react'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import List from '../list'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
describe('ToolListFlatView', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('assigns the first tool of each letter to the shared refs and renders the index bar', () => {
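// Shared ref map the list is expected to fill with the first provider element for each letter.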
const toolRefs = {
current: {} as Record<string, HTMLDivElement | null>,
}
render(
<List
letters={['A', 'B']}
payload={[
createToolProvider({
id: 'provider-a',
label: { en_US: 'A Provider', zh_Hans: 'A Provider' },
letter: 'A',
} as ReturnType<typeof createToolProvider>),
createToolProvider({
id: 'provider-b',
label: { en_US: 'B Provider', zh_Hans: 'B Provider' },
letter: 'B',
} as ReturnType<typeof createToolProvider>),
]}
isShowLetterIndex
indexBar={<div data-testid="index-bar" />}
hasSearchText={false}
onSelect={vi.fn()}
toolRefs={toolRefs}
/>,
)
expect(screen.getByText('A Provider')).toBeInTheDocument()
expect(screen.getByText('B Provider')).toBeInTheDocument()
expect(screen.getByTestId('index-bar')).toBeInTheDocument()
expect(toolRefs.current.A).toBeTruthy()
expect(toolRefs.current.B).toBeTruthy()
})
})

View File

@ -0,0 +1,47 @@
import { render, screen } from '@testing-library/react'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import Item from '../item'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
describe('ToolListTreeView Item', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('renders the group heading and its provider list', () => {
render(
<Item
groupName="My Group"
toolList={[createToolProvider({
label: { en_US: 'Provider Alpha', zh_Hans: 'Provider Alpha' },
})]}
hasSearchText={false}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('My Group')).toBeInTheDocument()
expect(screen.getByText('Provider Alpha')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,56 @@
import { render, screen } from '@testing-library/react'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import { CUSTOM_GROUP_NAME } from '../../../index-bar'
import List from '../list'
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
describe('ToolListTreeView', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
})
it('translates built-in special group names and renders the nested providers', () => {
render(
<List
payload={{
BuiltIn: [createToolProvider({
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
})],
[CUSTOM_GROUP_NAME]: [createToolProvider({
id: 'custom-provider',
type: CollectionType.custom,
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
})],
}}
hasSearchText={false}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('BuiltIn')).toBeInTheDocument()
expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument()
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,91 @@
import type { DataSet } from '@/models/datasets'
import { renderHook } from '@testing-library/react'
import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets'
import { DatasetsDetailContext } from '../provider'
import { createDatasetsDetailStore, useDatasetsDetailStore } from '../store'
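// Builds a fully populated DataSet fixture keyed by id for exercising the store.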
const createDataset = (id: string, name = `dataset-${id}`): DataSet => ({
id,
name,
indexing_status: 'completed',
icon_info: {
icon: 'book',
icon_type: 'emoji' as DataSet['icon_info']['icon_type'],
},
description: `${name} description`,
permission: DatasetPermission.onlyMe,
data_source_type: DataSourceType.FILE,
indexing_technique: 'high_quality' as DataSet['indexing_technique'],
created_by: 'user-1',
updated_by: 'user-1',
updated_at: 1,
app_count: 0,
doc_form: ChunkingMode.text,
document_count: 0,
total_document_count: 0,
word_count: 0,
provider: 'provider',
embedding_model: 'model',
embedding_model_provider: 'provider',
embedding_available: true,
retrieval_model_dict: {} as DataSet['retrieval_model_dict'],
retrieval_model: {} as DataSet['retrieval_model'],
tags: [],
external_knowledge_info: {
external_knowledge_id: '',
external_knowledge_api_id: '',
external_knowledge_api_name: '',
external_knowledge_api_endpoint: '',
},
external_retrieval_model: {
top_k: 1,
score_threshold: 0,
score_threshold_enabled: false,
},
built_in_field_enabled: false,
runtime_mode: 'general',
enable_api: false,
is_multimodal: false,
})
describe('datasets-detail-store store', () => {
it('merges dataset details by id', () => {
const store = createDatasetsDetailStore()
store.getState().updateDatasetsDetail([
createDataset('dataset-1', 'Dataset One'),
createDataset('dataset-2', 'Dataset Two'),
])
store.getState().updateDatasetsDetail([
createDataset('dataset-2', 'Dataset Two Updated'),
])
expect(store.getState().datasetsDetail).toMatchObject({
'dataset-1': { name: 'Dataset One' },
'dataset-2': { name: 'Dataset Two Updated' },
})
})
it('reads state from the datasets detail context', () => {
const store = createDatasetsDetailStore()
store.getState().updateDatasetsDetail([createDataset('dataset-3')])
const wrapper = ({ children }: { children: React.ReactNode }) => (
<DatasetsDetailContext.Provider value={store}>
{children}
</DatasetsDetailContext.Provider>
)
const { result } = renderHook(
() => useDatasetsDetailStore(state => state.datasetsDetail['dataset-3']?.name),
{ wrapper },
)
expect(result.current).toBe('dataset-dataset-3')
})
it('throws when the datasets detail provider is missing', () => {
expect(() => renderHook(() => useDatasetsDetailStore(state => state.datasetsDetail))).toThrow(
'Missing DatasetsDetailContext.Provider in the tree',
)
})
})

View File

@ -0,0 +1,41 @@
import { renderHook } from '@testing-library/react'
import { HooksStoreContext } from '../provider'
import { createHooksStore, useHooksStore } from '../store'
describe('hooks-store store', () => {
it('creates default callbacks and refreshes selected handlers', () => {
const store = createHooksStore({})
const handleBackupDraft = vi.fn()
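// A freshly created store exposes safe defaults until refreshAll swaps in real handlers.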
expect(store.getState().availableNodesMetaData).toEqual({ nodes: [] })
expect(store.getState().hasNodeInspectVars('node-1')).toBe(false)
expect(store.getState().getWorkflowRunAndTraceUrl('run-1')).toEqual({
runUrl: '',
traceUrl: '',
})
store.getState().refreshAll({ handleBackupDraft })
expect(store.getState().handleBackupDraft).toBe(handleBackupDraft)
})
it('reads state from the hooks store context', () => {
const handleRun = vi.fn()
const store = createHooksStore({ handleRun })
const wrapper = ({ children }: { children: React.ReactNode }) => (
<HooksStoreContext.Provider value={store}>
{children}
</HooksStoreContext.Provider>
)
const { result } = renderHook(() => useHooksStore(state => state.handleRun), { wrapper })
expect(result.current).toBe(handleRun)
})
it('throws when the hooks store provider is missing', () => {
expect(() => renderHook(() => useHooksStore(state => state.handleRun))).toThrow(
'Missing HooksStoreContext.Provider in the tree',
)
})
})

View File

@ -0,0 +1,19 @@
import { renderWorkflowHook } from '../../__tests__/workflow-test-env'
import { useDSL } from '../use-DSL'
describe('useDSL', () => {
it('returns the DSL handlers from hooks store', () => {
const exportCheck = vi.fn()
const handleExportDSL = vi.fn()
const { result } = renderWorkflowHook(() => useDSL(), {
hooksStoreProps: {
exportCheck,
handleExportDSL,
},
})
expect(result.current.exportCheck).toBe(exportCheck)
expect(result.current.handleExportDSL).toBe(handleExportDSL)
})
})

View File

@ -0,0 +1,90 @@
import { act, waitFor } from '@testing-library/react'
import { useEdges } from 'reactflow'
import { createEdge, createNode } from '../../__tests__/fixtures'
import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env'
import { NodeRunningStatus } from '../../types'
import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync'
type EdgeRuntimeState = {
_sourceRunningStatus?: NodeRunningStatus
_targetRunningStatus?: NodeRunningStatus
_waitingRun?: boolean
}
const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState =>
(edge?.data ?? {}) as EdgeRuntimeState
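// Fixtures: a three-node chain a -> b -> c with one edge still running and one already finished.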
const createFlowNodes = () => [
createNode({ id: 'a' }),
createNode({ id: 'b' }),
createNode({ id: 'c' }),
]
const createFlowEdges = () => [
createEdge({
id: 'e1',
source: 'a',
target: 'b',
data: {
_sourceRunningStatus: NodeRunningStatus.Running,
_targetRunningStatus: NodeRunningStatus.Running,
_waitingRun: true,
},
}),
createEdge({
id: 'e2',
source: 'b',
target: 'c',
data: {
_sourceRunningStatus: NodeRunningStatus.Succeeded,
_targetRunningStatus: undefined,
_waitingRun: false,
},
}),
]
const renderEdgesInteractionsHook = () =>
renderWorkflowFlowHook(() => ({
...useEdgesInteractionsWithoutSync(),
edges: useEdges(),
}), {
nodes: createFlowNodes(),
edges: createFlowEdges(),
})
describe('useEdgesInteractionsWithoutSync', () => {
it('clears running status and waitingRun on all edges', async () => {
const { result } = renderEdgesInteractionsHook()
act(() => {
result.current.handleEdgeCancelRunningStatus()
})
await waitFor(() => {
result.current.edges.forEach((edge) => {
const edgeState = getEdgeRuntimeState(edge)
expect(edgeState._sourceRunningStatus).toBeUndefined()
expect(edgeState._targetRunningStatus).toBeUndefined()
expect(edgeState._waitingRun).toBe(false)
})
})
})
it('does not mutate the original edges array', () => {
const edges = createFlowEdges()
const originalData = { ...getEdgeRuntimeState(edges[0]) }
const { result } = renderWorkflowFlowHook(() => ({
...useEdgesInteractionsWithoutSync(),
edges: useEdges(),
}), {
nodes: createFlowNodes(),
edges,
})
act(() => {
result.current.handleEdgeCancelRunningStatus()
})
expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus)
})
})

View File

@ -0,0 +1,114 @@
import { createEdge, createNode } from '../../__tests__/fixtures'
import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../../utils'
import {
applyConnectedHandleNodeData,
buildContextMenuEdges,
clearEdgeMenuIfNeeded,
clearNodeSelectionState,
updateEdgeHoverState,
updateEdgeSelectionState,
} from '../use-edges-interactions.helpers'
vi.mock('../../utils', () => ({
getNodesConnectedSourceOrTargetHandleIdsMap: vi.fn(),
}))
const mockGetNodesConnectedSourceOrTargetHandleIdsMap = vi.mocked(getNodesConnectedSourceOrTargetHandleIdsMap)
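// The connected-handle map util is mocked so tests can control what it reports.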
describe('use-edges-interactions.helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('applyConnectedHandleNodeData should merge connected handle metadata into matching nodes', () => {
mockGetNodesConnectedSourceOrTargetHandleIdsMap.mockReturnValue({
'node-1': {
_connectedSourceHandleIds: ['branch-a'],
},
})
const nodes = [
createNode({ id: 'node-1', data: { title: 'Source' } }),
createNode({ id: 'node-2', data: { title: 'Target' } }),
]
const edgeChanges = [{
type: 'add',
edge: createEdge({ id: 'edge-1', source: 'node-1', target: 'node-2' }),
}]
const result = applyConnectedHandleNodeData(nodes, edgeChanges)
expect(result[0].data._connectedSourceHandleIds).toEqual(['branch-a'])
expect(result[1].data._connectedSourceHandleIds).toEqual([])
expect(mockGetNodesConnectedSourceOrTargetHandleIdsMap).toHaveBeenCalledWith(edgeChanges, nodes)
})
it('clearEdgeMenuIfNeeded should return true only when the open menu belongs to a removed edge', () => {
expect(clearEdgeMenuIfNeeded({
edgeMenu: { edgeId: 'edge-1' },
edgeIds: ['edge-1', 'edge-2'],
})).toBe(true)
expect(clearEdgeMenuIfNeeded({
edgeMenu: { edgeId: 'edge-3' },
edgeIds: ['edge-1', 'edge-2'],
})).toBe(false)
expect(clearEdgeMenuIfNeeded({
edgeIds: ['edge-1'],
})).toBe(false)
})
it('updateEdgeHoverState should toggle only the hovered edge flag', () => {
const edges = [
createEdge({ id: 'edge-1', data: { _hovering: false } }),
createEdge({ id: 'edge-2', data: { _hovering: false } }),
]
const result = updateEdgeHoverState(edges, 'edge-2', true)
expect(result.find(edge => edge.id === 'edge-1')?.data._hovering).toBe(false)
expect(result.find(edge => edge.id === 'edge-2')?.data._hovering).toBe(true)
})
it('updateEdgeSelectionState should update selected flags for select changes only', () => {
const edges = [
createEdge({ id: 'edge-1', selected: false }),
createEdge({ id: 'edge-2', selected: true }),
]
const result = updateEdgeSelectionState(edges, [
{ type: 'select', id: 'edge-1', selected: true },
{ type: 'remove', id: 'edge-2' },
])
expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(true)
expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true)
})
it('buildContextMenuEdges should select the target edge and clear bundled markers', () => {
const edges = [
createEdge({ id: 'edge-1', selected: true, data: { _isBundled: true } }),
createEdge({ id: 'edge-2', selected: false, data: { _isBundled: true } }),
]
const result = buildContextMenuEdges(edges, 'edge-2')
expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(false)
expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true)
expect(result.every(edge => edge.data._isBundled === false)).toBe(true)
})
it('clearNodeSelectionState should clear selected state and bundled markers on every node', () => {
const nodes = [
createNode({ id: 'node-1', selected: true, data: { selected: true, _isBundled: true } }),
createNode({ id: 'node-2', selected: false, data: { selected: true, _isBundled: true } }),
]
const result = clearNodeSelectionState(nodes)
expect(result.every(node => node.selected === false)).toBe(true)
expect(result.every(node => node.data.selected === false)).toBe(true)
expect(result.every(node => node.data._isBundled === false)).toBe(true)
})
})

Some files were not shown because too many files have changed in this diff