mirror of https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00

Merge branch 'main' into jzh

This commit is contained in: commit e4c056a57a
.github/pull_request_template.md (vendored)
@@ -7,6 +7,7 @@
 ## Summary
 
 <!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
+<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
 
 ## Screenshots
 
@@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
 ## Getting Help
 
 If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
-
-## Automated Agent Contributions
-
-> [!NOTE]
-> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, TypedDict
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SQLAlchemyEngineOptionsDict(TypedDict):
|
||||
pool_size: int
|
||||
max_overflow: int
|
||||
pool_recycle: int
|
||||
pool_pre_ping: bool
|
||||
connect_args: dict[str, str]
|
||||
pool_use_lifo: bool
|
||||
pool_reset_on_return: None
|
||||
pool_timeout: int
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
|
||||
@ -209,11 +220,11 @@ class DatabaseConfig(BaseSettings):
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict:
|
||||
# Parse DB_EXTRAS for 'options'
|
||||
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
|
||||
options = db_extras_dict.get("options", "")
|
||||
connect_args = {}
|
||||
connect_args: dict[str, str] = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
@ -223,7 +234,7 @@ class DatabaseConfig(BaseSettings):
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
return {
|
||||
result: SQLAlchemyEngineOptionsDict = {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
|
||||
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
|
||||
@ -233,6 +244,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_reset_on_return": None,
|
||||
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class CeleryConfig(DatabaseConfig):
|
||||
|
||||
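The recurring pattern in this commit is replacing `dict[str, Any]` returns with `TypedDict`s so a type checker can verify every key and value type at the call site. A minimal standalone sketch of the payoff, with illustrative names rather than this codebase's:

```python
from typing import TypedDict


class EngineOptions(TypedDict):
    pool_size: int
    pool_pre_ping: bool


def make_options(size: int) -> EngineOptions:
    # The annotation lets mypy/pyright check every key and value type.
    options: EngineOptions = {
        "pool_size": size,
        "pool_pre_ping": True,
        # "pool_szie": 0,  # a typo like this becomes a type error, not a runtime surprise
    }
    return options
```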
@@ -1,9 +1,11 @@
 import json
-from typing import cast
+from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
+from pydantic import BaseModel, Field
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required

@@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig
 from services.app_model_config_service import AppModelConfigService
 
 
+class ModelConfigRequest(BaseModel):
+    provider: str | None = Field(default=None, description="Model provider")
+    model: str | None = Field(default=None, description="Model name")
+    configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters")
+    opening_statement: str | None = Field(default=None, description="Opening statement")
+    suggested_questions: list[str] | None = Field(default=None, description="Suggested questions")
+    more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration")
+    speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration")
+    text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration")
+    retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration")
+    tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools")
+    dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations")
+    agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration")
+
+
+register_schema_models(console_ns, ModelConfigRequest)
+
+
 @console_ns.route("/apps/<uuid:app_id>/model-config")
 class ModelConfigResource(Resource):
     @console_ns.doc("update_app_model_config")
     @console_ns.doc(description="Update application model configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "ModelConfigRequest",
-            {
-                "provider": fields.String(description="Model provider"),
-                "model": fields.String(description="Model name"),
-                "configs": fields.Raw(description="Model configuration parameters"),
-                "opening_statement": fields.String(description="Opening statement"),
-                "suggested_questions": fields.List(fields.String(), description="Suggested questions"),
-                "more_like_this": fields.Raw(description="More like this configuration"),
-                "speech_to_text": fields.Raw(description="Speech to text configuration"),
-                "text_to_speech": fields.Raw(description="Text to speech configuration"),
-                "retrieval_model": fields.Raw(description="Retrieval model configuration"),
-                "tools": fields.List(fields.Raw(), description="Available tools"),
-                "dataset_configs": fields.Raw(description="Dataset configurations"),
-                "agent_mode": fields.Raw(description="Agent mode configuration"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ModelConfigRequest.__name__])
     @console_ns.response(200, "Model configuration updated successfully")
     @console_ns.response(400, "Invalid configuration")
     @console_ns.response(404, "App not found")
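The hand-maintained flask_restx `fields` model is replaced by a Pydantic model whose JSON schema gets registered on the namespace through the project's `register_schema_models` helper. A rough sketch of what such a helper can boil down to — hypothetical, the real implementation in controllers.common.schema may do more (nested models, dedup, error handling) — using only documented flask-restx and Pydantic v2 calls:

```python
from flask_restx import Namespace
from pydantic import BaseModel


def register_schema_models_sketch(ns: Namespace, *models: type[BaseModel]) -> None:
    for model in models:
        # Swagger 2.0 keeps definitions under #/definitions/{model}, matching the
        # DEFAULT_REF_TEMPLATE_SWAGGER_2_0 constant this commit folds into the helper.
        schema = model.model_json_schema(ref_template="#/definitions/{model}")
        ns.schema_model(model.__name__, schema)
```

This is essentially the two-line loop the billing controller used to carry inline, hoisted into one shared function so every namespace registers Pydantic schemas the same way.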
@@ -2,18 +2,17 @@ import base64
 from typing import Literal
 
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
 from enums.cloud_plan import CloudPlan
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
 
 class SubscriptionQuery(BaseModel):
     plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")

@@ -24,8 +23,7 @@ class PartnerTenantsPayload(BaseModel):
     click_id: str = Field(..., description="Click Id from partner referral link")
 
 
-for model in (SubscriptionQuery, PartnerTenantsPayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload)
 
 
 @console_ns.route("/billing/subscription")

@@ -58,12 +56,7 @@ class PartnerTenants(Resource):
     @console_ns.doc("sync_partner_tenants_bindings")
     @console_ns.doc(description="Sync partner tenants bindings")
     @console_ns.doc(params={"partner_key": "Partner key"})
-    @console_ns.expect(
-        console_ns.model(
-            "SyncPartnerTenantsBindingsRequest",
-            {"click_id": fields.String(required=True, description="Click Id from partner referral link")},
-        )
-    )
+    @console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__])
     @console_ns.response(200, "Tenants synced to partner successfully")
     @console_ns.response(400, "Invalid partner information")
     @setup_required
@@ -162,7 +162,9 @@ class DataSourceApi(Resource):
         binding_id = str(binding_id)
         with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             data_source_binding = session.execute(
-                select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
+                select(DataSourceOauthBinding).where(
+                    DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
+                )
             ).scalar_one_or_none()
             if data_source_binding is None:
                 raise NotFound("Data source binding not found.")

@@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource):
                     raise ValueError("Dataset is not notion type.")
 
                 documents = session.scalars(
-                    select(Document).filter_by(
-                        dataset_id=query.dataset_id,
-                        tenant_id=current_tenant_id,
-                        data_source_type="notion_import",
-                        enabled=True,
+                    select(Document).where(
+                        Document.dataset_id == query.dataset_id,
+                        Document.tenant_id == current_tenant_id,
+                        Document.data_source_type == "notion_import",
+                        Document.enabled.is_(True),
                     )
                 ).all()
                 if documents:

@@ -280,7 +280,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
+        query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
 
         if status:
             query = DocumentService.apply_display_status_filter(query, status)

@@ -527,7 +527,7 @@ class DocumentListApi(DatasetApiResource):
         if not dataset:
            raise NotFound("Dataset not found.")
 
-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)
 
         if query_params.status:
             query = DocumentService.apply_display_status_filter(query, query_params.status)
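Many hunks below repeat the same mechanical migration: `filter_by(kwarg=value)` becomes `.where(Model.col == value)`, which type-checks against the mapped columns and keeps everything in SQLAlchemy 2.0 style. A minimal self-contained sketch of the pattern, on a toy model rather than this repo's schema:

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    enabled: Mapped[bool] = mapped_column(default=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy: session.query(Doc).filter_by(tenant_id="t1", enabled=True).all()
    # 2.0 style: explicit column expressions; .is_(True) avoids the `== True` lint.
    docs = session.scalars(
        select(Doc).where(Doc.tenant_id == "t1", Doc.enabled.is_(True))
    ).all()
```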
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import enum
 from enum import StrEnum
-from typing import Any
+from typing import Any, TypedDict
 
 from pydantic import BaseModel, Field, ValidationInfo, field_validator
 from yarl import URL

@@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
     datasources: list[DatasourceEntity] = Field(default_factory=list)
 
 
+class DatasourceInvokeMetaDict(TypedDict):
+    time_cost: float
+    error: str | None
+    tool_config: dict[str, Any] | None
+
+
 class DatasourceInvokeMeta(BaseModel):
     """
     Datasource invoke meta

@@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel):
         """
         return cls(time_cost=0.0, error=error, tool_config={})
 
-    def to_dict(self) -> dict:
-        return {
+    def to_dict(self) -> DatasourceInvokeMetaDict:
+        result: DatasourceInvokeMetaDict = {
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
         }
+        return result
 
 
 class DatasourceLabel(BaseModel):
@@ -2,7 +2,7 @@ import json
 import os
 from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, TypedDict, cast
 
 from graphon.file import file_manager
 from graphon.model_runtime.entities.message_entities import (

@@ -34,6 +34,13 @@ class ModelMode(StrEnum):
 prompt_file_contents: dict[str, Any] = {}
 
 
+class PromptTemplateConfigDict(TypedDict):
+    prompt_template: PromptTemplateParser
+    custom_variable_keys: list[str]
+    special_variable_keys: list[str]
+    prompt_rules: dict[str, Any]
+
+
 class SimplePromptTransform(PromptTransform):
     """
     Simple Prompt Transform for Chatbot App Basic Mode.

@@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
             with_memory_prompt=histories is not None,
         )
 
-        custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
-        special_variable_keys_obj = prompt_template_config["special_variable_keys"]
-
-        # Type check for custom_variable_keys
-        if not isinstance(custom_variable_keys_obj, list):
-            raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
-        custom_variable_keys = cast(list[str], custom_variable_keys_obj)
-
-        # Type check for special_variable_keys
-        if not isinstance(special_variable_keys_obj, list):
-            raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
-        special_variable_keys = cast(list[str], special_variable_keys_obj)
+        custom_variable_keys = prompt_template_config["custom_variable_keys"]
+        if not isinstance(custom_variable_keys, list):
+            raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
+
+        special_variable_keys = prompt_template_config["special_variable_keys"]
+        if not isinstance(special_variable_keys, list):
+            raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")
 
         variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
 

@@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
         has_context: bool,
         query_in_prompt: bool,
         with_memory_prompt: bool = False,
-    ) -> dict[str, object]:
+    ) -> PromptTemplateConfigDict:
         prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
 
         custom_variable_keys: list[str] = []

@@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
             prompt += prompt_rules.get("query_prompt", "{{#query#}}")
             special_variable_keys.append("#query#")
 
-        return {
+        result: PromptTemplateConfigDict = {
             "prompt_template": PromptTemplateParser(template=prompt),
             "custom_variable_keys": custom_variable_keys,
             "special_variable_keys": special_variable_keys,
             "prompt_rules": prompt_rules,
         }
+        return result
 
     def _get_chat_model_prompt_messages(
         self,
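Once the helper returns a `TypedDict`, indexing `prompt_template_config["custom_variable_keys"]` is already typed as `list[str]`, so the `cast` calls can go while the runtime `isinstance` guards stay as a belt-and-braces check. A small before/after sketch with hypothetical names:

```python
from typing import Any, TypedDict, cast


class Config(TypedDict):
    keys: list[str]
    rules: dict[str, Any]


def before(config: dict[str, object]) -> list[str]:
    # With dict[str, object] the value type is opaque: a cast is required.
    keys_obj = config["keys"]
    if not isinstance(keys_obj, list):
        raise TypeError(f"Expected list, got {type(keys_obj)}")
    return cast(list[str], keys_obj)


def after(config: Config) -> list[str]:
    # With a TypedDict, config["keys"] is already list[str] to the checker.
    return config["keys"]
```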
@@ -6,7 +6,7 @@ from collections.abc import Mapping
 from typing import Any
 
 from flask import current_app
-from sqlalchemy import delete, func, select
+from sqlalchemy import delete, func, select, update
 
 from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.index_type import IndexTechniqueType

@@ -63,11 +63,11 @@ class IndexProcessor:
         summary_index_setting: SummaryIndexSettingDict | None = None,
     ) -> IndexingResultDict:
         with session_factory.create_session() as session:
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if not document:
                 raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
 
-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
 

@@ -104,12 +104,12 @@ class IndexProcessor:
             document.indexing_status = "completed"
             document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             document.word_count = (
-                session.query(func.sum(DocumentSegment.word_count))
-                .where(
-                    DocumentSegment.document_id == document_id,
-                    DocumentSegment.dataset_id == dataset_id,
+                session.scalar(
+                    select(func.sum(DocumentSegment.word_count)).where(
+                        DocumentSegment.document_id == document_id,
+                        DocumentSegment.dataset_id == dataset_id,
+                    )
                 )
-                .scalar()
             ) or 0
             # Update need_summary based on dataset's summary_index_setting
             if summary_index_setting and summary_index_setting.get("enable") is True:

@@ -118,15 +118,17 @@ class IndexProcessor:
                 document.need_summary = False
             session.add(document)
             # update document segment status
-            session.query(DocumentSegment).where(
-                DocumentSegment.document_id == document_id,
-                DocumentSegment.dataset_id == dataset_id,
-            ).update(
-                {
-                    DocumentSegment.status: "completed",
-                    DocumentSegment.enabled: True,
-                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-                }
+            session.execute(
+                update(DocumentSegment)
+                .where(
+                    DocumentSegment.document_id == document_id,
+                    DocumentSegment.dataset_id == dataset_id,
+                )
+                .values(
+                    status="completed",
+                    enabled=True,
+                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                )
             )
 
             result: IndexingResultDict = {

@@ -151,11 +153,11 @@ class IndexProcessor:
         doc_language = None
         with session_factory.create_session() as session:
             if document_id:
-                document = session.query(Document).filter_by(id=document_id).first()
+                document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             else:
                 document = None
 
-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
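The legacy `Query.update({...})` calls become explicit `update()` statements executed through the session, which is the 2.0-native way to do a bulk update without loading ORM objects. A compact runnable sketch on a toy table:

```python
from sqlalchemy import create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[int]
    status: Mapped[str] = mapped_column(default="pending")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(document_id=1), Segment(document_id=1)])
    session.flush()
    # One UPDATE statement for every matching row; keyword args replace the
    # old {Model.col: value} dict style.
    session.execute(
        update(Segment).where(Segment.document_id == 1).values(status="completed")
    )
    session.commit()
```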
@@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         if node_ids:
             # Find segments by index_node_id
             with session_factory.create_session() as session:
-                segments = (
-                    session.query(DocumentSegment)
-                    .filter(
+                segments = session.scalars(
+                    select(DocumentSegment).where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                     )
-                    .all()
-                )
+                ).all()
                 segment_ids = [segment.id for segment in segments]
                 if segment_ids:
                     SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -8,6 +8,7 @@ from typing import Any, TypedDict
 
 import pandas as pd
 from flask import Flask, current_app
+from sqlalchemy import select
 from werkzeug.datastructures import FileStorage
 
 from core.db.session_factory import session_factory

@@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
         if node_ids:
             # Find segments by index_node_id
             with session_factory.create_session() as session:
-                segments = (
-                    session.query(DocumentSegment)
-                    .filter(
+                segments = session.scalars(
+                    select(DocumentSegment).where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                     )
-                    .all()
-                )
+                ).all()
                 segment_ids = [segment.id for segment in segments]
                 if segment_ids:
                     SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
 from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, literal, or_, select, update
 from sqlalchemy.orm import sessionmaker
 
 from core.app.app_config.entities import (

@@ -276,8 +276,8 @@ class DatasetRetrieval:
         document_ids = [i.segment.document_id for i in records]
 
         with session_factory.create_session() as session:
-            datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
-            documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
+            datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
+            documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
 
             dataset_map = {i.id: i for i in datasets}
             document_map = {i.id: i for i in documents}

@@ -971,9 +971,11 @@ class DatasetRetrieval:
 
             # Batch update hit_count for all segments
             if segment_ids_to_update:
-                session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
-                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                    synchronize_session=False,
+                session.execute(
+                    update(DocumentSegment)
+                    .where(DocumentSegment.id.in_(segment_ids_to_update))
+                    .values(hit_count=DocumentSegment.hit_count + 1)
+                    .execution_options(synchronize_session=False)
                 )
 
         self._send_trace_task(message_id, documents, timer)

@@ -1822,7 +1824,7 @@ class DatasetRetrieval:
     def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
         with session_factory.create_session() as session:
             subquery = (
-                session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
+                select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
                 .where(
                     DocumentModel.indexing_status == "completed",
                     DocumentModel.enabled == True,

@@ -1834,13 +1836,12 @@ class DatasetRetrieval:
                 .subquery()
             )
 
-            results = (
-                session.query(Dataset)
+            results = session.scalars(
+                select(Dataset)
                 .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
                 .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
                 .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
-                .all()
-            )
+            ).all()
 
             available_datasets = []
             for dataset in results:
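Worth noting in the hit-count hunk: the bump stays a server-side expression (`hit_count = hit_count + 1`), so concurrent retrievals don't overwrite each other with stale Python values; `synchronize_session=False` only skips reconciling already-loaded ORM instances. A self-contained sketch of the atomic increment:

```python
from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Hit(Base):
    __tablename__ = "hits"
    id: Mapped[int] = mapped_column(primary_key=True)
    hit_count: Mapped[int] = mapped_column(default=0)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Hit(id=1))
    session.flush()
    # Emits `SET hit_count = hit_count + 1` — the arithmetic runs in SQL.
    session.execute(
        update(Hit)
        .where(Hit.id.in_([1]))
        .values(hit_count=Hit.hit_count + 1)
        .execution_options(synchronize_session=False)
    )
    assert session.scalar(select(Hit.hit_count).where(Hit.id == 1)) == 1
```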
@@ -1,6 +1,8 @@
 import concurrent.futures
 import logging
 
+from sqlalchemy import select
+
 from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict

@@ -21,7 +23,7 @@ class SummaryIndex:
     ) -> None:
         if is_preview:
             with session_factory.create_session() as session:
-                dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+                dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
                 if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
                     return

@@ -34,32 +36,31 @@ class SummaryIndex:
             if not document_id:
                 return
 
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             # Skip qa_model documents
             if document is None or document.doc_form == "qa_model":
                 return
 
-            query = session.query(DocumentSegment).filter_by(
-                dataset_id=dataset_id,
-                document_id=document_id,
-                status="completed",
-                enabled=True,
-            )
-            segments = query.all()
+            segments = session.scalars(
+                select(DocumentSegment).where(
+                    DocumentSegment.dataset_id == dataset_id,
+                    DocumentSegment.document_id == document_id,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.enabled == True,
+                )
+            ).all()
             segment_ids = [segment.id for segment in segments]
 
             if not segment_ids:
                 return
 
-            existing_summaries = (
-                session.query(DocumentSegmentSummary)
-                .filter(
+            existing_summaries = session.scalars(
+                select(DocumentSegmentSummary).where(
                     DocumentSegmentSummary.chunk_id.in_(segment_ids),
                     DocumentSegmentSummary.dataset_id == dataset_id,
                     DocumentSegmentSummary.status == "completed",
                 )
-                .all()
-            )
+            ).all()
             completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
             # Preview mode should process segments that are MISSING completed summaries
             pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]

@@ -73,7 +74,7 @@ class SummaryIndex:
         def process_segment(segment_id: str) -> None:
             """Process a single segment in a thread with a fresh DB session."""
             with session_factory.create_session() as session:
-                segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
+                segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
                 if segment is None:
                     return
                 try:
@@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel):
     form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
 
 
+class ToolInvokeMetaDict(TypedDict):
+    time_cost: float
+    error: str | None
+    tool_config: dict[str, Any] | None
+
+
 class ToolInvokeMeta(BaseModel):
     """
     Tool invoke meta

@@ -473,12 +479,13 @@ class ToolInvokeMeta(BaseModel):
         """
         return cls(time_cost=0.0, error=error, tool_config={})
 
-    def to_dict(self):
-        return {
+    def to_dict(self) -> ToolInvokeMetaDict:
+        result: ToolInvokeMetaDict = {
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
         }
+        return result
 
 
 class ToolLabel(BaseModel):
@@ -262,6 +262,8 @@ class ToolEngine:
                         ensure_ascii=False,
                     )
                 )
+            elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
+                continue
             else:
                 parts.append(str(response.message))
 
@@ -682,7 +682,7 @@ class ToolManager:
 
         with Session(db.engine, autoflush=False) as session:
             ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
-            return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
+            return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))
 
     @classmethod
     def list_providers_from_api(

@@ -993,7 +993,7 @@ class ToolManager:
             return {"background": "#252525", "content": "\ud83d\ude01"}
 
     @classmethod
-    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
+    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
         try:
             with Session(db.engine) as session:
                 mcp_service = MCPToolManageService(session=session)

@@ -1001,7 +1001,7 @@ class ToolManager:
                 mcp_provider = mcp_service.get_provider_entity(
                     provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
                 )
-                return mcp_provider.provider_icon
+                return cast(EmojiIconDict | str, mcp_provider.provider_icon)
         except ValueError:
             raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
         except Exception:

@@ -1013,7 +1013,7 @@ class ToolManager:
         tenant_id: str,
         provider_type: ToolProviderType,
         provider_id: str,
-    ) -> str | EmojiIconDict | dict[str, str]:
+    ) -> str | EmojiIconDict:
         """
         get the tool icon
 
@@ -17,10 +17,8 @@ class WorkflowToolConfigurationUtils:
         """
         nodes = graph.get("nodes", [])
         start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
-
         if not start_node:
             return []
-
         return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
 
     @classmethod
@@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
 from redis.connection import Connection, SSLConnection
 from redis.retry import Retry
 from redis.sentinel import Sentinel
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from dify_app import DifyApp

@@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
 _pubsub_redis_client: redis.Redis | RedisCluster | None = None
 
 
+class RedisSSLParamsDict(TypedDict):
+    ssl_cert_reqs: int
+    ssl_ca_certs: str | None
+    ssl_certfile: str | None
+    ssl_keyfile: str | None
+
+
+class RedisHealthParamsDict(TypedDict):
+    retry: Retry
+    socket_timeout: float | None
+    socket_connect_timeout: float | None
+    health_check_interval: int | None
+
+
+class RedisBaseParamsDict(TypedDict):
+    username: str | None
+    password: str | None
+    db: int
+    encoding: str
+    encoding_errors: str
+    decode_responses: bool
+    protocol: int
+    cache_config: CacheConfig | None
+    retry: Retry
+    socket_timeout: float | None
+    socket_connect_timeout: float | None
+    health_check_interval: int | None
+
+
 def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
     """Get SSL configuration for Redis connection."""
     if not dify_config.REDIS_USE_SSL:

@@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
     )
 
 
-def _get_connection_health_params() -> dict[str, Any]:
+def _get_connection_health_params() -> RedisHealthParamsDict:
     """Get connection health and retry parameters for standalone and Sentinel Redis clients."""
-    return {
-        "retry": _get_retry_policy(),
-        "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
-        "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
-        "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
-    }
+    return RedisHealthParamsDict(
+        retry=_get_retry_policy(),
+        socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
+        socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
+        health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
+    )
 
 
 def _get_cluster_connection_health_params() -> dict[str, Any]:

@@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
     here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
     are passed through.
     """
-    params = _get_connection_health_params()
+    params: dict[str, Any] = dict(_get_connection_health_params())
     return {k: v for k, v in params.items() if k != "health_check_interval"}
 
 
-def _get_base_redis_params() -> dict[str, Any]:
+def _get_base_redis_params() -> RedisBaseParamsDict:
     """Get base Redis connection parameters including retry and health policy."""
-    return {
-        "username": dify_config.REDIS_USERNAME,
-        "password": dify_config.REDIS_PASSWORD or None,
-        "db": dify_config.REDIS_DB,
-        "encoding": "utf-8",
-        "encoding_errors": "strict",
-        "decode_responses": False,
-        "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
-        "cache_config": _get_cache_configuration(),
+    return RedisBaseParamsDict(
+        username=dify_config.REDIS_USERNAME,
+        password=dify_config.REDIS_PASSWORD or None,
+        db=dify_config.REDIS_DB,
+        encoding="utf-8",
+        encoding_errors="strict",
+        decode_responses=False,
+        protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
+        cache_config=_get_cache_configuration(),
         **_get_connection_health_params(),
-    }
+    )
 
 
-def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
+def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
     """Create Redis client using Sentinel configuration."""
     if not dify_config.REDIS_SENTINELS:
         raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")

@@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis,
         sentinel_kwargs=sentinel_kwargs,
     )
 
-    master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
+    params: dict[str, Any] = {**redis_params}
+    master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
     return master
 

@@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
     return cluster
 
 
-def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
+def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
     """Create standalone Redis client."""
     connection_class, ssl_kwargs = _get_ssl_configuration()
 
-    params = {**redis_params}
-    params.update(
-        {
-            "host": dify_config.REDIS_HOST,
-            "port": dify_config.REDIS_PORT,
-            "connection_class": connection_class,
-        }
-    )
+    params: dict[str, Any] = {
+        **redis_params,
+        "host": dify_config.REDIS_HOST,
+        "port": dify_config.REDIS_PORT,
+        "connection_class": connection_class,
+    }
 
     if dify_config.REDIS_MAX_CONNECTIONS:
         params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS

@@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
             kwargs["max_connections"] = max_conns
         return RedisCluster.from_url(pubsub_url, **kwargs)
 
-    health_params = _get_connection_health_params()
-    kwargs = {**health_params}
+    standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
+    kwargs = {**standalone_health_params}
    if max_conns:
         kwargs["max_connections"] = max_conns
     return redis.Redis.from_url(pubsub_url, **kwargs)
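Two idioms in this file are worth calling out: a TypedDict can be constructed with keyword-call syntax (`RedisHealthParamsDict(retry=..., ...)`), and it must be widened back to `dict[str, Any]` before you mutate it or `**`-unpack it into an API typed with open kwargs. A minimal sketch with illustrative names:

```python
from typing import Any, TypedDict


class HealthParams(TypedDict):
    socket_timeout: float | None
    health_check_interval: int | None


def get_params() -> HealthParams:
    # Call syntax type-checks the keys exactly like an annotated dict literal.
    return HealthParams(socket_timeout=5.0, health_check_interval=30)


def connect(**kwargs: Any) -> None:  # stand-in for a client constructor
    print(sorted(kwargs))


params: dict[str, Any] = dict(get_params())  # widen before adding undeclared keys
params["max_connections"] = 10  # not a key of HealthParams
connect(**params)
```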
@@ -14,9 +14,15 @@ from __future__ import annotations
 
 import logging
 import threading
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
+import redis
+from redis.cluster import RedisCluster
 from redis.exceptions import LockNotOwnedError, RedisError
+from redis.lock import Lock
 
+if TYPE_CHECKING:
+    from extensions.ext_redis import RedisClientWrapper
+
 logger = logging.getLogger(__name__)
 

@@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
     primary error/exit code.
     """
 
-    _redis_client: Any
+    _redis_client: redis.Redis | RedisCluster | RedisClientWrapper
     _name: str
     _ttl_seconds: float
     _renew_interval_seconds: float
     _log_context: str | None
     _logger: logging.Logger
 
-    _lock: Any
+    _lock: Lock | None
     _stop_event: threading.Event | None
     _thread: threading.Thread | None
    _acquired: bool
 
     def __init__(
         self,
-        redis_client: Any,
+        redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
         name: str,
         ttl_seconds: float = 60,
         renew_interval_seconds: float | None = None,

@@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
         )
         self._thread.start()
 
-    def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
+    def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
         while not stop_event.wait(self._renew_interval_seconds):
             try:
                 lock.reacquire()
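The lock's heartbeat pattern — a daemon thread calling redis-py's `Lock.reacquire()` (which resets the key's TTL) until an `Event` fires — can be sketched standalone like this. Simplified: the real class also tracks acquisition state and swallows/logs renewal errors rather than letting the thread die.

```python
import threading

from redis import Redis
from redis.lock import Lock


def run_with_auto_renew(client: Redis, name: str, ttl: float = 60.0) -> None:
    lock: Lock = client.lock(name, timeout=ttl)
    if not lock.acquire(blocking=False):
        raise RuntimeError(f"lock {name!r} is already held")
    stop = threading.Event()

    def heartbeat() -> None:
        # Renew at a fraction of the TTL so the lock cannot expire mid-work.
        while not stop.wait(ttl / 3):
            lock.reacquire()  # resets the key's TTL back to `timeout`

    thread = threading.Thread(target=heartbeat, daemon=True)
    thread.start()
    try:
        pass  # ... long-running migration work goes here ...
    finally:
        stop.set()
        thread.join()
        lock.release()
```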
@@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
 
 
 class DefaultFieldsMixin:
+    """Mixin for models that inherit from Base (non-dataclass)."""
+
     id: Mapped[str] = mapped_column(
         StringUUID,
         primary_key=True,

@@ -53,6 +55,42 @@ class DefaultFieldsMixin:
         return f"<{self.__class__.__name__}(id={self.id})>"
 
 
+class DefaultFieldsDCMixin(MappedAsDataclass):
+    """Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
+
+    __abstract__ = True
+
+    id: Mapped[str] = mapped_column(
+        StringUUID,
+        primary_key=True,
+        insert_default=lambda: str(uuidv7()),
+        default_factory=lambda: str(uuidv7()),
+        init=False,
+    )
+
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime,
+        nullable=False,
+        insert_default=naive_utc_now,
+        default_factory=naive_utc_now,
+        init=False,
+        server_default=func.current_timestamp(),
+    )
+
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime,
+        nullable=False,
+        insert_default=naive_utc_now,
+        default_factory=naive_utc_now,
+        init=False,
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+    )
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__}(id={self.id})>"
+
+
 def gen_uuidv4_string() -> str:
     """gen_uuidv4_string generate a UUIDv4 string.
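Dataclass-mapped models need their defaults declared twice over: `default_factory`/`init=False` for the dataclass layer on top of `insert_default` for the SQL layer, which is why a second mixin exists instead of reusing `DefaultFieldsMixin`. A toy model showing how such a mixin is consumed — simplified types, using stdlib `uuid4`/`datetime.now` where the real code uses `StringUUID`, `uuidv7`, and `naive_utc_now`:

```python
import uuid
from datetime import datetime

from sqlalchemy import String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class TypeBase(MappedAsDataclass, DeclarativeBase):
    pass


class DCMixin(MappedAsDataclass):
    __abstract__ = True

    # default_factory satisfies dataclass semantics; init=False keeps the
    # field out of the generated __init__ signature.
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True,
        default_factory=lambda: str(uuid.uuid4()), init=False,
    )
    created_at: Mapped[datetime] = mapped_column(
        default_factory=datetime.now, init=False,
        server_default=func.current_timestamp(),
    )


class Widget(DCMixin, TypeBase):
    __tablename__ = "widgets"
    name: Mapped[str]


w = Widget(name="demo")  # id/created_at are filled by the factories
```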
@@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict):
     created_at: str
 
 
+class DocumentDict(TypedDict):
+    id: str
+    tenant_id: str
+    dataset_id: str
+    position: int
+    data_source_type: str
+    data_source_info: str | None
+    dataset_process_rule_id: str | None
+    batch: str
+    name: str
+    created_from: str
+    created_by: str
+    created_api_request_id: str | None
+    created_at: datetime
+    processing_started_at: datetime | None
+    file_id: str | None
+    word_count: int | None
+    parsing_completed_at: datetime | None
+    cleaning_completed_at: datetime | None
+    splitting_completed_at: datetime | None
+    tokens: int | None
+    indexing_latency: float | None
+    completed_at: datetime | None
+    is_paused: bool | None
+    paused_by: str | None
+    paused_at: datetime | None
+    error: str | None
+    stopped_at: datetime | None
+    indexing_status: str
+    enabled: bool
+    disabled_at: datetime | None
+    disabled_by: str | None
+    archived: bool
+    archived_reason: str | None
+    archived_by: str | None
+    archived_at: datetime | None
+    updated_at: datetime
+    doc_type: str | None
+    doc_metadata: Any
+    doc_form: IndexStructureType
+    doc_language: str | None
+    display_status: str | None
+    data_source_info_dict: dict[str, Any]
+    average_segment_length: int
+    dataset_process_rule: ProcessRuleDict | None
+    dataset: None
+    segment_count: int | None
+    hit_count: int | None
+
+
 class DatasetPermissionEnum(enum.StrEnum):
     ONLY_ME = "only_me"
     ALL_TEAM = "all_team_members"

@@ -675,8 +725,8 @@ class Document(Base):
         )
         return built_in_fields
 
-    def to_dict(self) -> dict[str, Any]:
-        return {
+    def to_dict(self) -> DocumentDict:
+        result: DocumentDict = {
             "id": self.id,
             "tenant_id": self.tenant_id,
             "dataset_id": self.dataset_id,

@@ -721,10 +771,11 @@ class Document(Base):
             "data_source_info_dict": self.data_source_info_dict,
             "average_segment_length": self.average_segment_length,
             "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
-            "dataset": None,  # Dataset class doesn't have a to_dict method
+            "dataset": None,
             "segment_count": self.segment_count,
             "hit_count": self.hit_count,
         }
+        return result
 
     @classmethod
     def from_dict(cls, data: dict[str, Any]):
@@ -809,11 +809,11 @@ class AccountService:
         rest of the system gradually normalizes new inputs.
         """
         with session_factory.create_session() as session:
-            account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+            account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
             if account or email == email.lower():
                 return account
 
-            return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
+            return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
 
     @classmethod
     def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
 import click
 from flask import Flask, current_app
 from graphon.model_runtime.utils.encoders import jsonable_encoder
-from sqlalchemy import select
+from sqlalchemy import delete, func, select
 from sqlalchemy.orm import Session, sessionmaker
 
 from configs import dify_config

@@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:
 
         for model, table_name in related_tables:
             # Query records related to expired messages
-            records = (
-                session.query(model)
-                .where(
+            records = session.scalars(
+                select(model).where(
                     model.message_id.in_(batch_message_ids),  # type: ignore
                 )
-                .all()
-            )
+            ).all()
 
             if len(records) == 0:
                 continue

@@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
             except Exception:
                 logger.exception("Failed to save %s records", table_name)
 
-            session.query(model).where(
-                model.id.in_(record_ids),  # type: ignore
-            ).delete(synchronize_session=False)
+            session.execute(
+                delete(model)
+                .where(
+                    model.id.in_(record_ids),  # type: ignore
+                )
+                .execution_options(synchronize_session=False)
+            )
 
         click.echo(
             click.style(

@@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
         app_ids = [app.id for app in apps]
         while True:
             with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
-                messages = (
-                    session.query(Message)
+                messages = session.scalars(
+                    select(Message)
                     .where(
                         Message.app_id.in_(app_ids),
                         Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                     )
                     .limit(batch)
-                    .all()
-                )
+                ).all()
                 if len(messages) == 0:
                     break

@@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
                 message_ids = [message.id for message in messages]
 
                 # delete messages
-                session.query(Message).where(
-                    Message.id.in_(message_ids),
-                ).delete(synchronize_session=False)
+                session.execute(
+                    delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
+                )
 
                 cls._clear_message_related_tables(session, tenant_id, message_ids)

@@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:
 
         while True:
             with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
-                conversations = (
-                    session.query(Conversation)
+                conversations = session.scalars(
+                    select(Conversation)
                     .where(
                         Conversation.app_id.in_(app_ids),
                         Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
                     )
                     .limit(batch)
-                    .all()
-                )
+                ).all()
 
                 if len(conversations) == 0:
                     break

@@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
                 )
 
                 conversation_ids = [conversation.id for conversation in conversations]
-                session.query(Conversation).where(
-                    Conversation.id.in_(conversation_ids),
-                ).delete(synchronize_session=False)
+                session.execute(
+                    delete(Conversation)
+                    .where(Conversation.id.in_(conversation_ids))
+                    .execution_options(synchronize_session=False)
+                )
 
                 click.echo(
                     click.style(

@@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:
 
         while True:
             with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
-                workflow_app_logs = (
-                    session.query(WorkflowAppLog)
+                workflow_app_logs = session.scalars(
+                    select(WorkflowAppLog)
                     .where(
                         WorkflowAppLog.tenant_id == tenant_id,
                         WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                     )
                     .limit(batch)
-                    .all()
-                )
+                ).all()
 
                 if len(workflow_app_logs) == 0:
                     break

@@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
                 workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
 
                 # delete workflow app logs
-                session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
-                    synchronize_session=False
+                session.execute(
+                    delete(WorkflowAppLog)
+                    .where(WorkflowAppLog.id.in_(workflow_app_log_ids))
+                    .execution_options(synchronize_session=False)
                 )
 
                 click.echo(

@@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
         current_time = started_at
 
         with sessionmaker(db.engine).begin() as session:
-            total_tenant_count = session.query(Tenant.id).count()
+            total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
 
             click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

@@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
                 tenant_count = 0
                 for test_interval in test_intervals:
                     tenant_count = (
-                        session.query(Tenant.id)
-                        .where(Tenant.created_at.between(current_time, current_time + test_interval))
-                        .count()
+                        session.scalar(
+                            select(func.count(Tenant.id)).where(
+                                Tenant.created_at.between(current_time, current_time + test_interval)
+                            )
+                        )
+                        or 0
                     )
                     if tenant_count <= 100:
                         interval = test_interval

@@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:
 
                 batch_end = min(current_time + interval, ended_at)
 
-                rs = (
-                    session.query(Tenant.id)
+                rs = session.execute(
+                    select(Tenant.id)
                     .where(Tenant.created_at.between(current_time, batch_end))
                     .order_by(Tenant.created_at)
                 )
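The delete migrations mirror the update ones: `Query.delete(synchronize_session=False)` becomes `session.execute(delete(Model).where(...).execution_options(synchronize_session=False))`. Runnable sketch on a toy table:

```python
from sqlalchemy import create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Log(Base):
    __tablename__ = "logs"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Log(tenant_id="t1"), Log(tenant_id="t2")])
    session.flush()
    # Bulk DELETE; synchronize_session=False skips scanning the identity map.
    session.execute(
        delete(Log).where(Log.tenant_id == "t1").execution_options(synchronize_session=False)
    )
    assert session.scalar(select(func.count(Log.id))) == 1
```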
@@ -552,8 +552,8 @@ class DatasetService:
             external_knowledge_api_id: External knowledge API identifier
         """
         with sessionmaker(db.engine).begin() as session:
-            external_knowledge_binding = (
-                session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
+            external_knowledge_binding = session.scalar(
+                select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
             )
 
             if not external_knowledge_binding:
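Throughout the service layer, `Query.first()` maps onto `session.scalar(select(...).limit(1))`: `scalar()` returns the first column of the first row, or `None` when there is no row, and the explicit `.limit(1)` keeps the emitted SQL equivalent to the old query. A tiny sketch:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Binding(Base):
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Binding(dataset_id="d1"))
    session.flush()
    # Equivalent to the legacy .filter_by(dataset_id="d1").first()
    binding = session.scalar(select(Binding).where(Binding.dataset_id == "d1").limit(1))
    assert binding is not None
```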
@@ -1454,15 +1454,17 @@ class DocumentService:
         document_id_list: list[str] = [str(document_id) for document_id in document_ids]
 
         with session_factory.create_session() as session:
-            updated_count = (
-                session.query(Document)
-                .filter(
+            result = session.execute(
+                update(Document)
+                .where(
                     Document.id.in_(document_id_list),
                     Document.dataset_id == dataset_id,
                     Document.doc_form != IndexStructureType.QA_INDEX,  # Skip qa_model documents
                 )
-                .update({Document.need_summary: need_summary}, synchronize_session=False)
+                .values(need_summary=need_summary)
+                .execution_options(synchronize_session=False)
             )
+            updated_count = result.rowcount  # type: ignore[union-attr,attr-defined]
             session.commit()
             logger.info(
                 "Updated need_summary to %s for %d documents in dataset %s",

@@ -2822,6 +2824,10 @@ class DocumentService:
 
         knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
 
+        if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
+            if not knowledge_config.process_rule.rules.parent_mode:
+                knowledge_config.process_rule.rules.parent_mode = "paragraph"
+
         if not knowledge_config.process_rule.rules.segmentation:
             raise ValueError("Process rule segmentation is required")
@@ -4,7 +4,7 @@ from collections.abc import Mapping
 from typing import Any
 
 from graphon.model_runtime.entities.provider_entities import FormType
-from sqlalchemy import func, select
+from sqlalchemy import delete, func, select, update
 from sqlalchemy.orm import Session, sessionmaker
 
 from configs import dify_config

@@ -54,11 +54,13 @@ class DatasourceProviderService:
         remove oauth custom client params
         """
         with sessionmaker(bind=db.engine).begin() as session:
-            session.query(DatasourceOauthTenantParamConfig).filter_by(
-                tenant_id=tenant_id,
-                provider=datasource_provider_id.provider_name,
-                plugin_id=datasource_provider_id.plugin_id,
-            ).delete()
+            session.execute(
+                delete(DatasourceOauthTenantParamConfig).where(
+                    DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
+                    DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
+                    DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
+                )
+            )
 
     def decrypt_datasource_provider_credentials(
         self,

@@ -110,15 +112,21 @@ class DatasourceProviderService:
         """
         with sessionmaker(bind=db.engine).begin() as session:
             if credential_id:
-                datasource_provider = (
-                    session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
+                datasource_provider = session.scalar(
+                    select(DatasourceProvider)
+                    .where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
+                    .limit(1)
                 )
             else:
-                datasource_provider = (
-                    session.query(DatasourceProvider)
-                    .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+                datasource_provider = session.scalar(
+                    select(DatasourceProvider)
+                    .where(
+                        DatasourceProvider.tenant_id == tenant_id,
+                        DatasourceProvider.provider == provider,
+                        DatasourceProvider.plugin_id == plugin_id,
+                    )
                     .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
-                    .first()
+                    .limit(1)
                 )
             if not datasource_provider:
                 return {}

@@ -173,12 +181,15 @@ class DatasourceProviderService:
         get all datasource credentials by provider
         """
         with sessionmaker(bind=db.engine).begin() as session:
-            datasource_providers = (
-                session.query(DatasourceProvider)
-                .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+            datasource_providers = session.scalars(
+                select(DatasourceProvider)
+                .where(
+                    DatasourceProvider.tenant_id == tenant_id,
+                    DatasourceProvider.provider == provider,
+                    DatasourceProvider.plugin_id == plugin_id,
+                )
                 .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
-                .all()
-            )
+            ).all()
             if not datasource_providers:
                 return []
             current_user = get_current_user()

@@ -232,15 +243,15 @@ class DatasourceProviderService:
         update datasource provider name
         """
         with sessionmaker(bind=db.engine).begin() as session:
-            target_provider = (
-                session.query(DatasourceProvider)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    id=credential_id,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
+            target_provider = session.scalar(
+                select(DatasourceProvider)
+                .where(
+                    DatasourceProvider.tenant_id == tenant_id,
+                    DatasourceProvider.id == credential_id,
+                    DatasourceProvider.provider == datasource_provider_id.provider_name,
+                    DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
                 )
-                .first()
+                .limit(1)
             )
             if target_provider is None:
                 raise ValueError("provider not found")

@@ -250,16 +261,16 @@ class DatasourceProviderService:
 
             # check name is exist
             if (
-                session.query(DatasourceProvider)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    name=name,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
+                session.scalar(
+                    select(func.count(DatasourceProvider.id)).where(
+                        DatasourceProvider.tenant_id == tenant_id,
+                        DatasourceProvider.name == name,
+                        DatasourceProvider.provider == datasource_provider_id.provider_name,
+                        DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
+                    )
                 )
-                .count()
-                > 0
-            ):
+                or 0
+            ) > 0:
                 raise ValueError("Authorization name is already exists")
 
             target_provider.name = name

@@ -273,26 +284,31 @@ class DatasourceProviderService:
         """
         with sessionmaker(bind=db.engine).begin() as session:
             # get provider
-            target_provider = (
-                session.query(DatasourceProvider)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    id=credential_id,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
+            target_provider = session.scalar(
+                select(DatasourceProvider)
+                .where(
+                    DatasourceProvider.tenant_id == tenant_id,
+                    DatasourceProvider.id == credential_id,
+                    DatasourceProvider.provider == datasource_provider_id.provider_name,
+                    DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
                 )
-                .first()
+                .limit(1)
             )
             if target_provider is None:
                 raise ValueError("provider not found")
 
             # clear default provider
-            session.query(DatasourceProvider).filter_by(
-                tenant_id=tenant_id,
-                provider=target_provider.provider,
-                plugin_id=target_provider.plugin_id,
-                is_default=True,
-            ).update({"is_default": False})
+            session.execute(
+                update(DatasourceProvider)
+                .where(
+                    DatasourceProvider.tenant_id == tenant_id,
+                    DatasourceProvider.provider == target_provider.provider,
+                    DatasourceProvider.plugin_id == target_provider.plugin_id,
+                    DatasourceProvider.is_default.is_(True),
+                )
+                .values(is_default=False)
+                .execution_options(synchronize_session=False)
+            )
 
             # set new default provider
             target_provider.is_default = True

@@ -311,14 +327,14 @@ class DatasourceProviderService:
         if client_params is None and enabled is None:
             return
         with sessionmaker(bind=db.engine).begin() as session:
-            tenant_oauth_client_params = (
-                session.query(DatasourceOauthTenantParamConfig)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=datasource_provider_id.provider_name,
-                    plugin_id=datasource_provider_id.plugin_id,
+            tenant_oauth_client_params = session.scalar(
+                select(DatasourceOauthTenantParamConfig)
+                .where(
+                    DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
+                    DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
+                    DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
                 )
-                .first()
+                .limit(1)
             )
 
             if not tenant_oauth_client_params:

@@ -351,9 +367,14 @@ class DatasourceProviderService:
         """
         with Session(db.engine).no_autoflush as session:
             return (
-                session.query(DatasourceOauthParamConfig)
-                .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
-                .first()
+                session.scalar(
+                    select(DatasourceOauthParamConfig)
+                    .where(
+                        DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
+                        DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
+                    )
+                    .limit(1)
+                )
                 is not None
             )

@@ -423,15 +444,15 @@ class DatasourceProviderService:
         plugin_id = datasource_provider_id.plugin_id
         with Session(db.engine).no_autoflush as session:
             # get tenant oauth client params
-            tenant_oauth_client_params = (
-                session.query(DatasourceOauthTenantParamConfig)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=provider,
-                    plugin_id=plugin_id,
-                    enabled=True,
+            tenant_oauth_client_params = session.scalar(
+                select(DatasourceOauthTenantParamConfig)
+                .where(
+                    DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
+                    DatasourceOauthTenantParamConfig.provider == provider,
+                    DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
+                    DatasourceOauthTenantParamConfig.enabled.is_(True),
                 )
-                .first()
+                .limit(1)
             )
             if tenant_oauth_client_params:
                 encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)

@@ -443,8 +464,13 @@ class DatasourceProviderService:
             is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
             if is_verified:
                 # fallback to system oauth client params
-                oauth_client_params = (
-                    session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
+                oauth_client_params = session.scalar(
+                    select(DatasourceOauthParamConfig)
+                    .where(
+                        DatasourceOauthParamConfig.provider == provider,
+                        DatasourceOauthParamConfig.plugin_id == plugin_id,
+                    )
+                    .limit(1)
                 )
                 if oauth_client_params:
                     return oauth_client_params.system_credentials

@@ -455,15 +481,13 @@ class DatasourceProviderService:
 
     def generate_next_datasource_provider_name(
         session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
||||
) -> str:
|
||||
db_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
db_providers = session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in db_providers],
|
||||
f"{credential_type.get_name()}",
|
||||
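Note: the default-clearing hunk above swaps legacy `Query.update()` for an explicit `update()` statement run through `session.execute()`. A minimal runnable sketch of that bulk-update pattern follows; the `Provider` model and in-memory SQLite engine are illustrative stand-ins, not Dify's actual models:

from sqlalchemy import Boolean, String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Provider(Base):
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    is_default: Mapped[bool] = mapped_column(Boolean, default=False)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    session.add_all([Provider(tenant_id="t1", is_default=True), Provider(tenant_id="t1")])

with sessionmaker(bind=engine).begin() as session:
    # A single bulk UPDATE issued directly against the table;
    # synchronize_session=False skips reconciling already-loaded ORM
    # instances, matching the execution option used in the diff above.
    session.execute(
        update(Provider)
        .where(Provider.tenant_id == "t1", Provider.is_default.is_(True))
        .values(is_default=False)
        .execution_options(synchronize_session=False)
    )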
@@ -485,8 +509,10 @@ class DatasourceProviderService:
         with sessionmaker(bind=db.engine).begin() as session:
             lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
             with redis_client.lock(lock, timeout=20):
-                target_provider = (
-                    session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
+                target_provider = session.scalar(
+                    select(DatasourceProvider)
+                    .where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
+                    .limit(1)
                 )
                 if target_provider is None:
                     raise ValueError("provider not found")
@@ -496,25 +522,28 @@ class DatasourceProviderService:
                     db_provider_name = target_provider.name
                 else:
                     name_conflict = (
-                        session.query(DatasourceProvider)
-                        .filter_by(
-                            tenant_id=tenant_id,
-                            name=db_provider_name,
-                            provider=provider_id.provider_name,
-                            plugin_id=provider_id.plugin_id,
-                            auth_type=CredentialType.OAUTH2.value,
+                        session.scalar(
+                            select(func.count(DatasourceProvider.id)).where(
+                                DatasourceProvider.tenant_id == tenant_id,
+                                DatasourceProvider.name == db_provider_name,
+                                DatasourceProvider.provider == provider_id.provider_name,
+                                DatasourceProvider.plugin_id == provider_id.plugin_id,
+                                DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
+                            )
                         )
-                        .count()
+                        or 0
                     )
                     if name_conflict > 0:
                         db_provider_name = generate_incremental_name(
                             [
                                 provider.name
-                                for provider in session.query(DatasourceProvider).filter_by(
-                                    tenant_id=tenant_id,
-                                    provider=provider_id.provider_name,
-                                    plugin_id=provider_id.plugin_id,
-                                )
+                                for provider in session.scalars(
+                                    select(DatasourceProvider).where(
+                                        DatasourceProvider.tenant_id == tenant_id,
+                                        DatasourceProvider.provider == provider_id.provider_name,
+                                        DatasourceProvider.plugin_id == provider_id.plugin_id,
+                                    )
+                                ).all()
                             ],
                             db_provider_name,
                         )
@@ -556,25 +585,27 @@ class DatasourceProviderService:
                     )
                 else:
                     if (
-                        session.query(DatasourceProvider)
-                        .filter_by(
-                            tenant_id=tenant_id,
-                            name=db_provider_name,
-                            provider=provider_id.provider_name,
-                            plugin_id=provider_id.plugin_id,
-                            auth_type=credential_type.value,
+                        session.scalar(
+                            select(func.count(DatasourceProvider.id)).where(
+                                DatasourceProvider.tenant_id == tenant_id,
+                                DatasourceProvider.name == db_provider_name,
+                                DatasourceProvider.provider == provider_id.provider_name,
+                                DatasourceProvider.plugin_id == provider_id.plugin_id,
+                                DatasourceProvider.auth_type == credential_type.value,
+                            )
                         )
-                        .count()
-                        > 0
-                    ):
+                        or 0
+                    ) > 0:
                         db_provider_name = generate_incremental_name(
                             [
                                 provider.name
-                                for provider in session.query(DatasourceProvider).filter_by(
-                                    tenant_id=tenant_id,
-                                    provider=provider_id.provider_name,
-                                    plugin_id=provider_id.plugin_id,
-                                )
+                                for provider in session.scalars(
+                                    select(DatasourceProvider).where(
+                                        DatasourceProvider.tenant_id == tenant_id,
+                                        DatasourceProvider.provider == provider_id.provider_name,
+                                        DatasourceProvider.plugin_id == provider_id.plugin_id,
+                                    )
+                                ).all()
                             ],
                             db_provider_name,
                         )
@@ -627,11 +658,16 @@ class DatasourceProviderService:

             # check name is exist
             if (
-                session.query(DatasourceProvider)
-                .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
-                .count()
-                > 0
-            ):
+                session.scalar(
+                    select(func.count(DatasourceProvider.id)).where(
+                        DatasourceProvider.tenant_id == tenant_id,
+                        DatasourceProvider.plugin_id == plugin_id,
+                        DatasourceProvider.provider == provider_name,
+                        DatasourceProvider.name == db_provider_name,
+                    )
+                )
+                or 0
+            ) > 0:
                 raise ValueError("Authorization name is already exists")

             try:
@@ -918,21 +954,31 @@ class DatasourceProviderService:
         """

         with sessionmaker(bind=db.engine).begin() as session:
-            datasource_provider = (
-                session.query(DatasourceProvider)
-                .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
-                .first()
+            datasource_provider = session.scalar(
+                select(DatasourceProvider)
+                .where(
+                    DatasourceProvider.tenant_id == tenant_id,
+                    DatasourceProvider.id == auth_id,
+                    DatasourceProvider.provider == provider,
+                    DatasourceProvider.plugin_id == plugin_id,
+                )
+                .limit(1)
             )
             if not datasource_provider:
                 raise ValueError("Datasource provider not found")
             # update name
             if name and name != datasource_provider.name:
                 if (
-                    session.query(DatasourceProvider)
-                    .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
-                    .count()
-                    > 0
-                ):
+                    session.scalar(
+                        select(func.count(DatasourceProvider.id)).where(
+                            DatasourceProvider.tenant_id == tenant_id,
+                            DatasourceProvider.name == name,
+                            DatasourceProvider.provider == provider,
+                            DatasourceProvider.plugin_id == plugin_id,
+                        )
+                    )
+                    or 0
+                ) > 0:
                    raise ValueError("Authorization name is already exists")
                datasource_provider.name = name

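Note: every lookup hunk above follows the same fetch-one conversion, `.filter_by(...).first()` becoming `session.scalar(select(...).limit(1))`. A minimal sketch of that pattern, again with a toy model rather than Dify's real ones:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Provider(Base):
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Provider(tenant_id="t1", name="API Key 1"))
    session.commit()

    # Legacy: session.query(Provider).filter_by(tenant_id="t1").first()
    # 2.0 style: scalar() returns the first mapped object or None; the
    # explicit limit(1) keeps the emitted SQL equivalent to what first()
    # produced.
    provider = session.scalar(
        select(Provider).where(Provider.tenant_id == "t1").limit(1)
    )
    assert provider is not None and provider.name == "API Key 1"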
@@ -13,6 +13,7 @@ import sqlalchemy as sa
 import tqdm
 from flask import Flask, current_app
 from pydantic import TypeAdapter
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session

 from core.agent.entities import AgentToolEntity
@@ -66,7 +67,7 @@ class PluginMigration:
         current_time = started_at

         with Session(db.engine) as session:
-            total_tenant_count = session.query(Tenant.id).count()
+            total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0

             click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

@@ -123,9 +124,12 @@ class PluginMigration:
             tenant_count = 0
             for test_interval in test_intervals:
                 tenant_count = (
-                    session.query(Tenant.id)
-                    .where(Tenant.created_at.between(current_time, current_time + test_interval))
-                    .count()
+                    session.scalar(
+                        select(func.count(Tenant.id)).where(
+                            Tenant.created_at.between(current_time, current_time + test_interval)
+                        )
+                    )
+                    or 0
                 )
                 if tenant_count <= 100:
                     interval = test_interval
@@ -147,8 +151,8 @@ class PluginMigration:

                 batch_end = min(current_time + interval, ended_at)

-                rs = (
-                    session.query(Tenant.id)
+                rs = session.execute(
+                    select(Tenant.id)
                     .where(Tenant.created_at.between(current_time, batch_end))
                     .order_by(Tenant.created_at)
                 )
@@ -235,7 +239,7 @@ class PluginMigration:
         Extract tool tables.
         """
         with Session(db.engine) as session:
-            rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
+            rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all()
             result = []
             for row in rs:
                 result.append(ToolProviderID(row.provider).plugin_id)
@@ -249,7 +253,7 @@ class PluginMigration:
         """

         with Session(db.engine) as session:
-            rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
+            rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all()
             result = []
             for row in rs:
                 graph = row.graph_dict
@@ -272,7 +276,7 @@ class PluginMigration:
         Extract app tables.
         """
         with Session(db.engine) as session:
-            apps = session.query(App).where(App.tenant_id == tenant_id).all()
+            apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
             if not apps:
                 return []

@@ -280,7 +284,7 @@ class PluginMigration:
                 app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
             ]

-            rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
+            rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all()
             result = []
             for row in rs:
                 agent_config = row.agent_mode_dict

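Note: the counting hunks above convert `Query.count()` to `session.scalar(select(func.count(...)))`. A minimal sketch of why the `or 0` guard appears (toy model, not Dify's):

from sqlalchemy import String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Tenant(Base):
    __tablename__ = "tenants"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Tenant(name="a"), Tenant(name="b")])
    session.commit()

    # session.scalar() is typed as possibly returning None, so the
    # `or 0` seen in the diff narrows the type for checkers (COUNT
    # itself always yields exactly one row).
    total = session.scalar(select(func.count(Tenant.id))) or 0
    assert total == 2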
@@ -283,7 +283,9 @@ class RagPipelineDslService:
         ):
             raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
-            datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            datasets = self._session.scalars(
+                select(Dataset).where(Dataset.tenant_id == account.current_tenant_id)
+            ).all()
             names = [dataset.name for dataset in datasets]
             generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
@@ -303,8 +305,8 @@ class RagPipelineDslService:
                 chunk_structure=knowledge_configuration.chunk_structure,
             )
             if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
-                dataset_collection_binding = (
-                    self._session.query(DatasetCollectionBinding)
+                dataset_collection_binding = self._session.scalar(
+                    select(DatasetCollectionBinding)
                     .where(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -312,7 +314,7 @@ class RagPipelineDslService:
                         DatasetCollectionBinding.type == CollectionBindingType.DATASET,
                     )
                     .order_by(DatasetCollectionBinding.created_at)
-                    .first()
+                    .limit(1)
                 )

                 if not dataset_collection_binding:
@@ -440,8 +442,8 @@ class RagPipelineDslService:
             dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
             dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
-                dataset_collection_binding = (
-                    self._session.query(DatasetCollectionBinding)
+                dataset_collection_binding = self._session.scalar(
+                    select(DatasetCollectionBinding)
                     .where(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -449,7 +451,7 @@ class RagPipelineDslService:
                         DatasetCollectionBinding.type == CollectionBindingType.DATASET,
                     )
                     .order_by(DatasetCollectionBinding.created_at)
-                    .first()
+                    .limit(1)
                 )

                 if not dataset_collection_binding:
@@ -591,14 +593,14 @@ class RagPipelineDslService:
             IMPORT_INFO_REDIS_EXPIRY,
             CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
         )
-        workflow = (
-            self._session.query(Workflow)
+        workflow = self._session.scalar(
+            select(Workflow)
             .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
             )
-            .first()
+            .limit(1)
         )

         # create draft workflow if not found
@@ -665,14 +667,12 @@ class RagPipelineDslService:
         :param pipeline: Pipeline instance
         """

-        workflow = (
-            self._session.query(Workflow)
-            .where(
+        workflow = self._session.scalar(
+            select(Workflow).where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
             )
-            .first()
         )
         if not workflow:
             raise ValueError("Missing draft workflow configuration, please check.")
@@ -904,15 +904,16 @@ class RagPipelineDslService:
         ):
             if rag_pipeline_dataset_create_entity.name:
                 # check if dataset name already exists
-                if (
-                    self._session.query(Dataset)
-                    .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
-                    .first()
+                if self._session.scalar(
+                    select(Dataset).where(
+                        Dataset.name == rag_pipeline_dataset_create_entity.name,
+                        Dataset.tenant_id == tenant_id,
+                    )
                 ):
                     raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
             else:
                 # generate a random name as Untitled 1 2 3 ...
-                datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+                datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
                 names = [dataset.name for dataset in datasets]
                 rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                     names,

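Note: the name-collision hunk above relies on the truthiness of `session.scalar(select(...))` as an existence check, since it returns the mapped object or None. A small sketch under the same assumption (toy `Dataset` model, not Dify's):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Dataset(tenant_id="t1", name="Untitled 1"))
    session.commit()

    # scalar() yields the object or None, so the result works directly
    # as an existence test inside an if-condition.
    if session.scalar(
        select(Dataset).where(Dataset.tenant_id == "t1", Dataset.name == "Untitled 1")
    ):
        print("name already taken")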
@@ -8,6 +8,7 @@ from typing import TypedDict, cast

 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.model_runtime.entities.model_entities import ModelType
+from sqlalchemy import select
 from sqlalchemy.orm import Session

 from core.db.session_factory import session_factory
@@ -109,8 +110,13 @@ class SummaryIndexService:
         """
         with session_factory.create_session() as session:
             # Check if summary record already exists
-            existing_summary = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            existing_summary = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )

             if existing_summary:
@@ -309,8 +315,10 @@ class SummaryIndexService:
                     summary_record_id,
                     segment.id,
                 )
-                summary_record_in_session = (
-                    session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+                summary_record_in_session = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(DocumentSegmentSummary.id == summary_record_id)
+                    .limit(1)
                 )

                 if not summary_record_in_session:
@@ -323,10 +331,13 @@ class SummaryIndexService:
                         dataset.id,
                         segment.id,
                     )
-                    summary_record_in_session = (
-                        session.query(DocumentSegmentSummary)
-                        .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
-                        .first()
+                    summary_record_in_session = session.scalar(
+                        select(DocumentSegmentSummary)
+                        .where(
+                            DocumentSegmentSummary.chunk_id == segment.id,
+                            DocumentSegmentSummary.dataset_id == dataset.id,
+                        )
+                        .limit(1)
                     )

                     if not summary_record_in_session:
@@ -487,8 +498,10 @@ class SummaryIndexService:
             with session_factory.create_session() as error_session:
                 # Try to find the record by id first
                 # Note: Using assignment only (no type annotation) to avoid redeclaration error
-                summary_record_in_session = (
-                    error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+                summary_record_in_session = error_session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(DocumentSegmentSummary.id == summary_record_id)
+                    .limit(1)
                 )
                 if not summary_record_in_session:
                     # Try to find by chunk_id and dataset_id
@@ -500,10 +513,13 @@ class SummaryIndexService:
                         dataset.id,
                         segment.id,
                     )
-                    summary_record_in_session = (
-                        error_session.query(DocumentSegmentSummary)
-                        .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
-                        .first()
+                    summary_record_in_session = error_session.scalar(
+                        select(DocumentSegmentSummary)
+                        .where(
+                            DocumentSegmentSummary.chunk_id == segment.id,
+                            DocumentSegmentSummary.dataset_id == dataset.id,
+                        )
+                        .limit(1)
                     )

                 if summary_record_in_session:
@@ -551,14 +567,12 @@ class SummaryIndexService:

         with session_factory.create_session() as session:
             # Query existing summary records
-            existing_summaries = (
-                session.query(DocumentSegmentSummary)
-                .filter(
+            existing_summaries = session.scalars(
+                select(DocumentSegmentSummary).where(
                     DocumentSegmentSummary.chunk_id.in_(segment_ids),
                     DocumentSegmentSummary.dataset_id == dataset.id,
                 )
-                .all()
-            )
+            ).all()
             existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}

             # Create or update records
@@ -603,8 +617,13 @@ class SummaryIndexService:
             error: Error message
         """
         with session_factory.create_session() as session:
-            summary_record = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            summary_record = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )

             if summary_record:
@@ -639,8 +658,13 @@ class SummaryIndexService:
         with session_factory.create_session() as session:
             try:
                 # Get or refresh summary record in this session
-                summary_record_in_session = (
-                    session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+                summary_record_in_session = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(
+                        DocumentSegmentSummary.chunk_id == segment.id,
+                        DocumentSegmentSummary.dataset_id == dataset.id,
+                    )
+                    .limit(1)
                 )

                 if not summary_record_in_session:
@@ -710,8 +734,13 @@ class SummaryIndexService:
             except Exception as e:
                 logger.exception("Failed to generate summary for segment %s", segment.id)
                 # Update summary record with error status
-                summary_record_in_session = (
-                    session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+                summary_record_in_session = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(
+                        DocumentSegmentSummary.chunk_id == segment.id,
+                        DocumentSegmentSummary.dataset_id == dataset.id,
+                    )
+                    .limit(1)
                 )
                 if summary_record_in_session:
                     summary_record_in_session.status = SummaryStatus.ERROR
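Note: throughout these hunks, boolean filters change from keyword equality (`enabled=True` / `== True`) to `Column.is_(True)`. A short sketch of the difference (toy model; `.is_()` renders an IS comparison and also satisfies linters such as ruff's E712):

from sqlalchemy import Boolean, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(enabled=True), Segment(enabled=False)])
    session.commit()

    # Segment.enabled.is_(True) emits "enabled IS ..." instead of
    # "enabled = true"; for this query both forms select the same rows.
    enabled_segments = session.scalars(
        select(Segment).where(Segment.enabled.is_(True))
    ).all()
    assert len(enabled_segments) == 1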
@@ -769,17 +798,17 @@ class SummaryIndexService:

         with session_factory.create_session() as session:
             # Query segments (only enabled segments)
-            query = session.query(DocumentSegment).filter_by(
-                dataset_id=dataset.id,
-                document_id=document.id,
-                status="completed",
-                enabled=True,  # Only generate summaries for enabled segments
+            stmt = select(DocumentSegment).where(
+                DocumentSegment.dataset_id == dataset.id,
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.status == "completed",
+                DocumentSegment.enabled.is_(True),  # Only generate summaries for enabled segments
             )

             if segment_ids:
-                query = query.filter(DocumentSegment.id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegment.id.in_(segment_ids))

-            segments = query.all()
+            segments = list(session.scalars(stmt).all())

             if not segments:
                 logger.info("No segments found for document %s", document.id)
@@ -848,15 +877,15 @@ class SummaryIndexService:
         from libs.datetime_utils import naive_utc_now

         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter_by(
-                dataset_id=dataset.id,
-                enabled=True,  # Only disable enabled summaries
+            stmt = select(DocumentSegmentSummary).where(
+                DocumentSegmentSummary.dataset_id == dataset.id,
+                DocumentSegmentSummary.enabled.is_(True),  # Only disable enabled summaries
             )

             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))

-            summaries = query.all()
+            summaries = session.scalars(stmt).all()

             if not summaries:
                 return
@@ -911,15 +940,15 @@ class SummaryIndexService:
             return

         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter_by(
-                dataset_id=dataset.id,
-                enabled=False,  # Only enable disabled summaries
+            stmt = select(DocumentSegmentSummary).where(
+                DocumentSegmentSummary.dataset_id == dataset.id,
+                DocumentSegmentSummary.enabled.is_(False),  # Only enable disabled summaries
             )

             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))

-            summaries = query.all()
+            summaries = session.scalars(stmt).all()

             if not summaries:
                 return
@@ -935,13 +964,13 @@ class SummaryIndexService:
             enabled_count = 0
             for summary in summaries:
                 # Get the original segment
-                segment = (
-                    session.query(DocumentSegment)
-                    .filter_by(
-                        id=summary.chunk_id,
-                        dataset_id=dataset.id,
+                segment = session.scalar(
+                    select(DocumentSegment)
+                    .where(
+                        DocumentSegment.id == summary.chunk_id,
+                        DocumentSegment.dataset_id == dataset.id,
                     )
-                    .first()
+                    .limit(1)
                 )

                 # Summary.enabled stays in sync with chunk.enabled,
@@ -988,12 +1017,12 @@ class SummaryIndexService:
             segment_ids: List of segment IDs to delete summaries for. If None, delete all.
         """
         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
+            stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)

             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))

-            summaries = query.all()
+            summaries = session.scalars(stmt).all()

             if not summaries:
                 return
@@ -1046,10 +1075,13 @@ class SummaryIndexService:
             # Check if summary_content is empty (whitespace-only strings are considered empty)
             if not summary_content or not summary_content.strip():
                 # If summary is empty, only delete existing summary vector and record
-                summary_record = (
-                    session.query(DocumentSegmentSummary)
-                    .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
-                    .first()
+                summary_record = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(
+                        DocumentSegmentSummary.chunk_id == segment.id,
+                        DocumentSegmentSummary.dataset_id == dataset.id,
+                    )
+                    .limit(1)
                 )

                 if summary_record:
@@ -1077,8 +1109,13 @@ class SummaryIndexService:
                 return None

             # Find existing summary record
-            summary_record = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            summary_record = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )

             if summary_record:
@@ -1162,8 +1199,13 @@ class SummaryIndexService:
         except Exception as e:
             logger.exception("Failed to update summary for segment %s", segment.id)
             # Update summary record with error status if it exists
-            summary_record = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            summary_record = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )
             if summary_record:
                 summary_record.status = SummaryStatus.ERROR
@@ -1185,14 +1227,14 @@ class SummaryIndexService:
             DocumentSegmentSummary instance if found, None otherwise
         """
         with session_factory.create_session() as session:
-            return (
-                session.query(DocumentSegmentSummary)
+            return session.scalar(
+                select(DocumentSegmentSummary)
                 .where(
                     DocumentSegmentSummary.chunk_id == segment_id,
                     DocumentSegmentSummary.dataset_id == dataset_id,
-                    DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
+                    DocumentSegmentSummary.enabled.is_(True),  # Only return enabled summaries
                 )
-                .first()
+                .limit(1)
             )

     @staticmethod
@@ -1211,15 +1253,13 @@ class SummaryIndexService:
             return {}

         with session_factory.create_session() as session:
-            summary_records = (
-                session.query(DocumentSegmentSummary)
-                .where(
+            summary_records = session.scalars(
+                select(DocumentSegmentSummary).where(
                     DocumentSegmentSummary.chunk_id.in_(segment_ids),
                     DocumentSegmentSummary.dataset_id == dataset_id,
-                    DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
+                    DocumentSegmentSummary.enabled.is_(True),  # Only return enabled summaries
                 )
-                .all()
-            )
+            ).all()

             return {summary.chunk_id: summary for summary in summary_records}

@@ -1239,16 +1279,16 @@ class SummaryIndexService:
             List of DocumentSegmentSummary instances (only enabled summaries)
         """
         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter(
+            stmt = select(DocumentSegmentSummary).where(
                 DocumentSegmentSummary.document_id == document_id,
                 DocumentSegmentSummary.dataset_id == dataset_id,
-                DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
+                DocumentSegmentSummary.enabled.is_(True),  # Only return enabled summaries
             )

             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))

-            return query.all()
+            return list(session.scalars(stmt).all())

     @staticmethod
     def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
@@ -1265,16 +1305,15 @@ class SummaryIndexService:
         """
         # Get all segments for this document (excluding qa_model and re_segment)
         with session_factory.create_session() as session:
-            segments = (
-                session.query(DocumentSegment.id)
-                .where(
-                    DocumentSegment.document_id == document_id,
-                    DocumentSegment.status != "re_segment",
-                    DocumentSegment.tenant_id == tenant_id,
-                )
-                .all()
-            )
-            segment_ids = [seg.id for seg in segments]
+            segment_ids = list(
+                session.scalars(
+                    select(DocumentSegment.id).where(
+                        DocumentSegment.document_id == document_id,
+                        DocumentSegment.status != "re_segment",
+                        DocumentSegment.tenant_id == tenant_id,
+                    )
+                ).all()
+            )

             if not segment_ids:
                 return None
@@ -1312,15 +1351,13 @@ class SummaryIndexService:

         # Get all segments for these documents (excluding qa_model and re_segment)
         with session_factory.create_session() as session:
-            segments = (
-                session.query(DocumentSegment.id, DocumentSegment.document_id)
-                .where(
+            segments = session.execute(
+                select(DocumentSegment.id, DocumentSegment.document_id).where(
                     DocumentSegment.document_id.in_(document_ids),
                     DocumentSegment.status != "re_segment",
                     DocumentSegment.tenant_id == tenant_id,
                 )
-                .all()
-            )
+            ).all()

             # Group segments by document_id
             document_segments_map: dict[str, list[str]] = {}

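Note: several hunks above build a statement incrementally (`stmt = select(...)`, then a conditional `stmt = stmt.where(...)`) before executing it. Because select() statements are immutable, each `.where()` returns a new statement. A runnable sketch with a toy model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Summary(Base):
    __tablename__ = "summaries"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))
    chunk_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)


def load_summaries(session: Session, dataset_id: str, chunk_ids: list[str] | None):
    # Narrow the statement only when a filter is supplied, then execute once.
    stmt = select(Summary).where(Summary.dataset_id == dataset_id)
    if chunk_ids:
        stmt = stmt.where(Summary.chunk_id.in_(chunk_ids))
    return list(session.scalars(stmt).all())


with Session(engine) as session:
    session.add(Summary(dataset_id="d1", chunk_id="c1"))
    session.commit()
    assert len(load_summaries(session, "d1", ["c1"])) == 1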
@@ -4,7 +4,7 @@ from collections.abc import Mapping
 from pathlib import Path
 from typing import Any

-from sqlalchemy import exists, select
+from sqlalchemy import delete, exists, func, select, update
 from sqlalchemy.orm import Session, sessionmaker

 from configs import dify_config
@@ -47,11 +47,15 @@ class BuiltinToolManageService:
         """
         tool_provider = ToolProviderID(provider)
         with sessionmaker(bind=db.engine).begin() as session:
-            session.query(ToolOAuthTenantClient).filter_by(
-                tenant_id=tenant_id,
-                provider=tool_provider.provider_name,
-                plugin_id=tool_provider.plugin_id,
-            ).delete()
+            session.execute(
+                delete(ToolOAuthTenantClient)
+                .where(
+                    ToolOAuthTenantClient.tenant_id == tenant_id,
+                    ToolOAuthTenantClient.provider == tool_provider.provider_name,
+                    ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
+                )
+                .execution_options(synchronize_session=False)
+            )
         return {"result": "success"}

     @staticmethod
@@ -151,13 +155,13 @@ class BuiltinToolManageService:
         """
         with sessionmaker(bind=db.engine).begin() as session:
             # get if the provider exists
-            db_provider = (
-                session.query(BuiltinToolProvider)
+            db_provider = session.scalar(
+                select(BuiltinToolProvider)
                 .where(
                     BuiltinToolProvider.tenant_id == tenant_id,
                     BuiltinToolProvider.id == credential_id,
                 )
-                .first()
+                .limit(1)
             )
             if db_provider is None:
                 raise ValueError(f"you have not added provider {provider}")
@@ -228,7 +232,13 @@ class BuiltinToolManageService:
                 raise ValueError(f"provider {provider} does not need credentials")

             provider_count = (
-                session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
+                session.scalar(
+                    select(func.count(BuiltinToolProvider.id)).where(
+                        BuiltinToolProvider.tenant_id == tenant_id,
+                        BuiltinToolProvider.provider == provider,
+                    )
+                )
+                or 0
             )

             # check if the provider count is reached the limit
@@ -304,16 +314,15 @@ class BuiltinToolManageService:
     def generate_builtin_tool_provider_name(
         session: Session, tenant_id: str, provider: str, credential_type: CredentialType
     ) -> str:
-        db_providers = (
-            session.query(BuiltinToolProvider)
-            .filter_by(
-                tenant_id=tenant_id,
-                provider=provider,
-                credential_type=credential_type,
+        db_providers = session.scalars(
+            select(BuiltinToolProvider)
+            .where(
+                BuiltinToolProvider.tenant_id == tenant_id,
+                BuiltinToolProvider.provider == provider,
+                BuiltinToolProvider.credential_type == credential_type,
             )
             .order_by(BuiltinToolProvider.created_at.desc())
-            .all()
-        )
+        ).all()
         return generate_incremental_name(
             [provider.name for provider in db_providers],
             f"{credential_type.get_name()}",
@@ -375,13 +384,13 @@ class BuiltinToolManageService:
         delete tool provider
         """
         with sessionmaker(bind=db.engine).begin() as session:
-            db_provider = (
-                session.query(BuiltinToolProvider)
+            db_provider = session.scalar(
+                select(BuiltinToolProvider)
                 .where(
                     BuiltinToolProvider.tenant_id == tenant_id,
                     BuiltinToolProvider.id == credential_id,
                 )
-                .first()
+                .limit(1)
             )

             if db_provider is None:
@@ -405,14 +414,26 @@ class BuiltinToolManageService:
         """
         with sessionmaker(bind=db.engine).begin() as session:
             # get provider
-            target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
+            target_provider = session.scalar(
+                select(BuiltinToolProvider)
+                .where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id)
+                .limit(1)
+            )
             if target_provider is None:
                 raise ValueError("provider not found")

             # clear default provider
-            session.query(BuiltinToolProvider).filter_by(
-                tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
-            ).update({"is_default": False})
+            session.execute(
+                update(BuiltinToolProvider)
+                .where(
+                    BuiltinToolProvider.tenant_id == tenant_id,
+                    BuiltinToolProvider.user_id == user_id,
+                    BuiltinToolProvider.provider == provider,
+                    BuiltinToolProvider.is_default.is_(True),
+                )
+                .values(is_default=False)
+                .execution_options(synchronize_session=False)
+            )

             # set new default provider
             target_provider.is_default = True
@@ -426,10 +447,13 @@ class BuiltinToolManageService:
         """
         tool_provider = ToolProviderID(provider_name)
         with Session(db.engine, autoflush=False) as session:
-            system_client: ToolOAuthSystemClient | None = (
-                session.query(ToolOAuthSystemClient)
-                .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
-                .first()
+            system_client = session.scalar(
+                select(ToolOAuthSystemClient)
+                .where(
+                    ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthSystemClient.provider == tool_provider.provider_name,
+                )
+                .limit(1)
             )
             return system_client is not None

@@ -440,15 +464,15 @@ class BuiltinToolManageService:
         """
         tool_provider = ToolProviderID(provider)
         with Session(db.engine, autoflush=False) as session:
-            user_client: ToolOAuthTenantClient | None = (
-                session.query(ToolOAuthTenantClient)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=tool_provider.provider_name,
-                    plugin_id=tool_provider.plugin_id,
-                    enabled=True,
+            user_client = session.scalar(
+                select(ToolOAuthTenantClient)
+                .where(
+                    ToolOAuthTenantClient.tenant_id == tenant_id,
+                    ToolOAuthTenantClient.provider == tool_provider.provider_name,
+                    ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthTenantClient.enabled.is_(True),
                 )
-                .first()
+                .limit(1)
             )
             return user_client is not None and user_client.enabled

@@ -465,15 +489,15 @@ class BuiltinToolManageService:
             cache=NoOpProviderCredentialCache(),
         )
         with Session(db.engine, autoflush=False) as session:
-            user_client: ToolOAuthTenantClient | None = (
-                session.query(ToolOAuthTenantClient)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    provider=tool_provider.provider_name,
-                    plugin_id=tool_provider.plugin_id,
-                    enabled=True,
+            user_client = session.scalar(
+                select(ToolOAuthTenantClient)
+                .where(
+                    ToolOAuthTenantClient.tenant_id == tenant_id,
+                    ToolOAuthTenantClient.provider == tool_provider.provider_name,
+                    ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthTenantClient.enabled.is_(True),
                 )
-                .first()
+                .limit(1)
             )
             oauth_params: Mapping[str, Any] | None = None
             if user_client:
@@ -487,10 +511,13 @@ class BuiltinToolManageService:
             if not is_verified:
                 return oauth_params

-            system_client: ToolOAuthSystemClient | None = (
-                session.query(ToolOAuthSystemClient)
-                .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
-                .first()
+            system_client = session.scalar(
+                select(ToolOAuthSystemClient)
+                .where(
+                    ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthSystemClient.provider == tool_provider.provider_name,
+                )
+                .limit(1)
             )
             if system_client:
                 try:
@@ -582,8 +609,8 @@ class BuiltinToolManageService:
                 provider_name = provider_id_entity.provider_name

                 if provider_id_entity.organization != "langgenius":
-                    provider = (
-                        session.query(BuiltinToolProvider)
+                    provider = session.scalar(
+                        select(BuiltinToolProvider)
                         .where(
                             BuiltinToolProvider.tenant_id == tenant_id,
                             BuiltinToolProvider.provider == full_provider_name,
@@ -592,11 +619,11 @@ class BuiltinToolManageService:
                             BuiltinToolProvider.is_default.desc(),  # default=True first
                             BuiltinToolProvider.created_at.asc(),  # oldest first
                         )
-                        .first()
+                        .limit(1)
                     )
                 else:
-                    provider = (
-                        session.query(BuiltinToolProvider)
+                    provider = session.scalar(
+                        select(BuiltinToolProvider)
                         .where(
                             BuiltinToolProvider.tenant_id == tenant_id,
                             (BuiltinToolProvider.provider == provider_name)
@@ -606,7 +633,7 @@ class BuiltinToolManageService:
                             BuiltinToolProvider.is_default.desc(),  # default=True first
                             BuiltinToolProvider.created_at.asc(),  # oldest first
                         )
-                        .first()
+                        .limit(1)
                     )

                 if provider is None:
@@ -616,14 +643,14 @@ class BuiltinToolManageService:
                 return provider
             except Exception:
                 # it's an old provider without organization
-                return (
-                    session.query(BuiltinToolProvider)
+                return session.scalar(
+                    select(BuiltinToolProvider)
                     .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
                     .order_by(
                         BuiltinToolProvider.is_default.desc(),  # default=True first
                         BuiltinToolProvider.created_at.asc(),  # oldest first
                     )
-                    .first()
+                    .limit(1)
                 )

     @staticmethod
@@ -648,14 +675,14 @@ class BuiltinToolManageService:
             raise ValueError(f"Provider {provider} is not a builtin or plugin provider")

         with sessionmaker(bind=db.engine).begin() as session:
-            custom_client_params = (
-                session.query(ToolOAuthTenantClient)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    plugin_id=tool_provider.plugin_id,
-                    provider=tool_provider.provider_name,
+            custom_client_params = session.scalar(
+                select(ToolOAuthTenantClient)
+                .where(
+                    ToolOAuthTenantClient.tenant_id == tenant_id,
+                    ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthTenantClient.provider == tool_provider.provider_name,
                )
-                .first()
+                .limit(1)
             )

             # if the record does not exist, create a basic record
@@ -692,14 +719,14 @@ class BuiltinToolManageService:
         """
         with Session(db.engine) as session:
             tool_provider = ToolProviderID(provider)
-            custom_oauth_client_params: ToolOAuthTenantClient | None = (
-                session.query(ToolOAuthTenantClient)
-                .filter_by(
-                    tenant_id=tenant_id,
-                    plugin_id=tool_provider.plugin_id,
-                    provider=tool_provider.provider_name,
+            custom_oauth_client_params = session.scalar(
+                select(ToolOAuthTenantClient)
+                .where(
+                    ToolOAuthTenantClient.tenant_id == tenant_id,
+                    ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
+                    ToolOAuthTenantClient.provider == tool_provider.provider_name,
                )
-                .first()
+                .limit(1)
             )
             if custom_oauth_client_params is None:
                 return {}

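Note: the OAuth-client removal hunks above replace `Query.delete()` with an explicit `delete()` statement. A minimal sketch of the bulk-delete pattern (toy model and SQLite engine, not Dify's):

from sqlalchemy import String, create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class OAuthClient(Base):
    __tablename__ = "oauth_clients"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    session.add(OAuthClient(tenant_id="t1"))

with sessionmaker(bind=engine).begin() as session:
    # delete() issues one DELETE statement; synchronize_session=False
    # skips reconciling in-memory objects, as in the diff above.
    session.execute(
        delete(OAuthClient)
        .where(OAuthClient.tenant_id == "t1")
        .execution_options(synchronize_session=False)
    )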
@@ -84,6 +84,7 @@ class WorkflowToolManageService:
         try:
             WorkflowToolProviderController.from_db(workflow_tool_provider)
         except Exception as e:
+            logger.warning(e, exc_info=True)
             raise ValueError(str(e))

         with Session(db.engine, expire_on_commit=False) as session, session.begin():

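Note: the context line above combines `expire_on_commit=False` with `session.begin()`. A small sketch of why that combination is useful; the model is a toy stand-in:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class ToolProvider(Base):
    __tablename__ = "tool_providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

# session.begin() wraps the block in one transaction that commits on
# success; expire_on_commit=False keeps attribute values loaded after
# that commit, so the object is still readable once the block exits.
with Session(engine, expire_on_commit=False) as session, session.begin():
    provider = ToolProvider(name="my_tool")
    session.add(provider)

print(provider.name)  # safe: attributes were not expired on commit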
@@ -3,9 +3,9 @@ import logging
 import time as _time
 import uuid
 from collections.abc import Mapping
-from typing import Any
+from typing import Any, TypedDict

-from sqlalchemy import desc, func
+from sqlalchemy import delete, desc, func, select
 from sqlalchemy.orm import Session, sessionmaker

 from configs import dify_config
@@ -42,6 +42,10 @@ from services.plugin.plugin_service import PluginService
 logger = logging.getLogger(__name__)


+class VerifyCredentialsResult(TypedDict):
+    verified: bool
+
+
 class TriggerProviderService:
     """Service for managing trigger providers and credentials"""

@@ -69,27 +73,28 @@ class TriggerProviderService:
         workflows_in_use_map: dict[str, int] = {}
         with Session(db.engine, expire_on_commit=False) as session:
             # Get all subscriptions
-            subscriptions_db = (
-                session.query(TriggerSubscription)
-                .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
+            subscriptions_db = session.scalars(
+                select(TriggerSubscription)
+                .where(
+                    TriggerSubscription.tenant_id == tenant_id,
+                    TriggerSubscription.provider_id == str(provider_id),
+                )
                 .order_by(desc(TriggerSubscription.created_at))
-                .all()
-            )
+            ).all()
             subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
             if not subscriptions:
                 return []
-            usage_counts = (
-                session.query(
+            usage_counts = session.execute(
+                select(
                     WorkflowPluginTrigger.subscription_id,
                     func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
                 )
-                .filter(
+                .where(
                     WorkflowPluginTrigger.tenant_id == tenant_id,
                     WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
                 )
                 .group_by(WorkflowPluginTrigger.subscription_id)
-                .all()
-            )
+            ).all()
             workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}

             provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
@@ -152,9 +157,13 @@ class TriggerProviderService:
             with redis_client.lock(lock_key, timeout=20):
                 # Check provider count limit
                 provider_count = (
-                    session.query(TriggerSubscription)
-                    .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
-                    .count()
+                    session.scalar(
+                        select(func.count(TriggerSubscription.id)).where(
+                            TriggerSubscription.tenant_id == tenant_id,
+                            TriggerSubscription.provider_id == str(provider_id),
+                        )
+                    )
+                    or 0
                 )

                 if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
@@ -164,10 +173,14 @@ class TriggerProviderService:
                     )

                 # Check if name already exists
-                existing = (
-                    session.query(TriggerSubscription)
-                    .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
-                    .first()
+                existing = session.scalar(
+                    select(TriggerSubscription)
+                    .where(
+                        TriggerSubscription.tenant_id == tenant_id,
+                        TriggerSubscription.provider_id == str(provider_id),
+                        TriggerSubscription.name == name,
+                    )
+                    .limit(1)
                 )
                 if existing:
                     raise ValueError(f"Credential name '{name}' already exists for this provider")
@@ -244,8 +257,13 @@ class TriggerProviderService:
             # Use distributed lock to prevent race conditions on the same subscription
             lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
             with redis_client.lock(lock_key, timeout=20):
-                subscription: TriggerSubscription | None = (
-                    session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+                subscription = session.scalar(
+                    select(TriggerSubscription)
+                    .where(
+                        TriggerSubscription.tenant_id == tenant_id,
+                        TriggerSubscription.id == subscription_id,
+                    )
+                    .limit(1)
                 )
                 if not subscription:
                     raise ValueError(f"Trigger subscription {subscription_id} not found")
@@ -255,10 +273,14 @@ class TriggerProviderService:

                 # Check for name uniqueness if name is being updated
                 if name is not None and name != subscription.name:
-                    existing = (
-                        session.query(TriggerSubscription)
-                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
-                        .first()
+                    existing = session.scalar(
+                        select(TriggerSubscription)
+                        .where(
+                            TriggerSubscription.tenant_id == tenant_id,
+                            TriggerSubscription.provider_id == str(provider_id),
+                            TriggerSubscription.name == name,
+                        )
+                        .limit(1)
                     )
                     if existing:
                         raise ValueError(f"Subscription name '{name}' already exists for this provider")
@@ -316,11 +338,18 @@ class TriggerProviderService:
         with Session(db.engine, expire_on_commit=False) as session:
             subscription: TriggerSubscription | None = None
             if subscription_id:
-                subscription = (
-                    session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+                subscription = session.scalar(
+                    select(TriggerSubscription)
+                    .where(
+                        TriggerSubscription.tenant_id == tenant_id,
+                        TriggerSubscription.id == subscription_id,
+                    )
+                    .limit(1)
                 )
             else:
-                subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
+                subscription = session.scalar(
+                    select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1)
+                )
             if subscription:
                 provider_controller = TriggerManager.get_trigger_provider(
                     tenant_id, TriggerProviderID(subscription.provider_id)
@@ -349,8 +378,13 @@ class TriggerProviderService:
         :param subscription_id: Subscription instance ID
         :return: Success response
         """
-        subscription: TriggerSubscription | None = (
-            session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+        subscription = session.scalar(
+            select(TriggerSubscription)
+            .where(
+                TriggerSubscription.tenant_id == tenant_id,
+                TriggerSubscription.id == subscription_id,
+            )
+            .limit(1)
         )
         if not subscription:
             raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@@ -402,7 +436,14 @@ class TriggerProviderService:
         :return: New token info
         """
         with sessionmaker(bind=db.engine).begin() as session:
-            subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+            subscription = session.scalar(
+                select(TriggerSubscription)
+                .where(
+                    TriggerSubscription.tenant_id == tenant_id,
+                    TriggerSubscription.id == subscription_id,
+                )
+                .limit(1)
+            )

             if not subscription:
                 raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@@ -475,8 +516,13 @@ class TriggerProviderService:
         now_ts: int = int(now if now is not None else _time.time())

         with sessionmaker(bind=db.engine).begin() as session:
-            subscription: TriggerSubscription | None = (
-                session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+            subscription = session.scalar(
+                select(TriggerSubscription)
+                .where(
+                    TriggerSubscription.tenant_id == tenant_id,
+                    TriggerSubscription.id == subscription_id,
+                )
+                .limit(1)
            )
            if subscription is None:
                raise ValueError(f"Trigger provider subscription {subscription_id} not found")
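Note: the creation hunks above keep a redis lock around the count and name-uniqueness checks. A sketch of that check-then-act pattern; it assumes a reachable Redis (as Dify's `redis_client` does), and the in-memory set plus helper functions are toy stand-ins for the database queries:

import redis

r = redis.Redis()  # assumes a local Redis server at call time

_names: set[tuple[str, str, str]] = set()  # toy stand-in for the subscriptions table


def subscription_name_exists(tenant_id: str, provider_id: str, name: str) -> bool:
    return (tenant_id, provider_id, name) in _names


def insert_subscription(tenant_id: str, provider_id: str, name: str) -> None:
    _names.add((tenant_id, provider_id, name))


def create_subscription(tenant_id: str, provider_id: str, name: str) -> None:
    # The distributed lock serializes check-then-insert across workers,
    # so two concurrent requests cannot both pass the uniqueness check.
    with r.lock(f"trigger_provider_create_lock:{tenant_id}_{provider_id}", timeout=20):
        if subscription_name_exists(tenant_id, provider_id, name):
            raise ValueError(f"Credential name '{name}' already exists for this provider")
        insert_subscription(tenant_id, provider_id, name)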
@ -552,15 +598,15 @@ class TriggerProviderService:
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
tenant_client: TriggerOAuthTenantClient | None = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
enabled=True,
|
||||
tenant_client = session.scalar(
|
||||
select(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthTenantClient.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
@ -578,10 +624,13 @@ class TriggerProviderService:
|
||||
return None
|
||||
|
||||
# Check for system-level OAuth client
|
||||
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)

if system_client:
@ -602,10 +651,13 @@ class TriggerProviderService:
if not is_verified:
return False
with Session(db.engine, expire_on_commit=False) as session:
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)
return system_client is not None

@ -636,14 +688,14 @@ class TriggerProviderService:

with sessionmaker(bind=db.engine).begin() as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)

# Create new record if doesn't exist
@ -690,14 +742,14 @@ class TriggerProviderService:
:return: Masked OAuth client parameters
"""
with Session(db.engine) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)

if custom_client is None:
@ -727,11 +779,15 @@ class TriggerProviderService:
:return: Success response
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.execute(
delete(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
)
.execution_options(synchronize_session=False)
)

return {"result": "success"}

@ -745,15 +801,15 @@ class TriggerProviderService:
:return: True if enabled, False otherwise
"""
with Session(db.engine, expire_on_commit=False) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
enabled=True,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
return custom_client is not None

@ -763,7 +819,9 @@ class TriggerProviderService:
Get a trigger subscription by the endpoint ID.
"""
with Session(db.engine, expire_on_commit=False) as session:
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
subscription = session.scalar(
select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1)
)
if not subscription:
return None
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
@ -792,7 +850,7 @@ class TriggerProviderService:
provider_id: TriggerProviderID,
subscription_id: str,
credentials: dict[str, Any],
) -> dict[str, Any]:
) -> VerifyCredentialsResult:
"""
Verify credentials for an existing subscription without updating it.

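Every hunk in the service above applies the same mechanical migration: a legacy `session.query(...).filter_by(...).first()` chain becomes a 2.0-style `select()` executed through `session.scalar()`, with an explicit `.limit(1)`. A minimal, self-contained sketch of the two equivalent forms, using a toy `Client` model rather than the real TriggerOAuthSystemClient table:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Client(Base):  # toy stand-in for a model such as TriggerOAuthSystemClient
    __tablename__ = "clients"
    id: Mapped[int] = mapped_column(primary_key=True)
    plugin_id: Mapped[str]
    provider: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Client(plugin_id="p1", provider="github"))
    session.commit()

    # Legacy Query API (what the diff removes):
    legacy = session.query(Client).filter_by(plugin_id="p1", provider="github").first()

    # 2.0-style statement (what the diff adds); scalar() returns the first
    # column of the first row, or None when nothing matches.
    modern = session.scalar(
        select(Client).where(Client.plugin_id == "p1", Client.provider == "github").limit(1)
    )
    assert legacy is modern  # same identity-mapped object either way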
@ -19,7 +19,7 @@ from graphon.variables.segments (
)
from graphon.variables.types import SegmentType
from graphon.variables.utils import dumps_with_segments
from sqlalchemy import Engine, orm, select
from sqlalchemy import Engine, delete, orm, select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, sessionmaker
@ -222,11 +222,10 @@ class WorkflowDraftVariableService:
)

def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
return self._session.scalar(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(WorkflowDraftVariable.id == variable_id)
.first()
)

def get_draft_variables_by_selectors(
@ -254,20 +253,21 @@ class WorkflowDraftVariableService:
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
query = (
self._session.query(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
return list(
self._session.scalars(
select(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
)
)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
return query.all()

def list_variables_without_values(
self, app_id: str, page: int, limit: int, user_id: str
@ -277,18 +277,21 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.user_id == user_id,
]
total = None
query = self._session.query(WorkflowDraftVariable).where(*criteria)
base_stmt = select(WorkflowDraftVariable).where(*criteria)
if page == 1:
total = query.count()
variables = (
# Do not load the `value` field
query.options(
orm.defer(WorkflowDraftVariable.value, raiseload=True),
from sqlalchemy import func as sa_func

total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery()))
variables = list(
self._session.scalars(
# Do not load the `value` field
base_stmt.options(
orm.defer(WorkflowDraftVariable.value, raiseload=True),
)
.order_by(WorkflowDraftVariable.created_at.desc())
.limit(limit)
.offset((page - 1) * limit)
)
.order_by(WorkflowDraftVariable.created_at.desc())
.limit(limit)
.offset((page - 1) * limit)
.all()
)

return WorkflowDraftVariableList(variables=variables, total=total)
@ -299,11 +302,13 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
]
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = (
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.order_by(WorkflowDraftVariable.created_at.desc())
.all()
variables = list(
self._session.scalars(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(*criteria)
.order_by(WorkflowDraftVariable.created_at.desc())
)
)
return WorkflowDraftVariableList(variables=variables)

@ -326,8 +331,8 @@ class WorkflowDraftVariableService:
return self._get_variable(app_id, node_id, name, user_id=user_id)

def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
return self._session.scalar(
select(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(
WorkflowDraftVariable.app_id == app_id,
@ -335,7 +340,6 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.name == name,
WorkflowDraftVariable.user_id == user_id,
)
.first()
)

def update_variable(
@ -488,20 +492,20 @@ class WorkflowDraftVariableService:
self._session.delete(variable)

def delete_user_workflow_variables(self, app_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)

def delete_app_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)

def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]):
@ -540,14 +544,14 @@ class WorkflowDraftVariableService:
return self._delete_node_variables(app_id, node_id, user_id=user_id)

def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
self._session.execute(
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
.execution_options(synchronize_session=False)
)

def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
@ -588,13 +592,11 @@ class WorkflowDraftVariableService:
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)

if conv_id is not None:
conversation = (
self._session.query(Conversation)
.where(
conversation = self._session.scalar(
select(Conversation).where(
Conversation.id == conv_id,
Conversation.app_id == workflow.app_id,
)
.first()
)
# Only return the conversation ID if it exists and is valid (has a corresponding conversation record in DB).
if conversation is not None:

@ -1512,14 +1512,12 @@ class WorkflowService:

# Don't use workflow.tool_published as it's not accurate for specific workflow versions
# Check if there's a tool provider using this specific workflow version
tool_provider = (
session.query(WorkflowToolProvider)
.where(
tool_provider = session.scalar(
select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version,
)
.first()
)

if tool_provider:

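One subtlety in the list_variables_without_values hunk above: `Query.count()` has no one-line equivalent on a `select()` statement, so the new code wraps the filtered statement in a subquery and counts its rows. A runnable sketch of that pattern against a toy model (not the real WorkflowDraftVariable table):

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Variable(Base):  # toy stand-in for WorkflowDraftVariable
    __tablename__ = "variables"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Variable(app_id="a1"), Variable(app_id="a1"), Variable(app_id="a2")])
    session.commit()

    base_stmt = select(Variable).where(Variable.app_id == "a1")
    # Wrap the filtered statement in a subquery and count its rows;
    # this is the statement-level replacement for Query.count().
    total = session.scalar(select(func.count()).select_from(base_stmt.subquery()))
    assert total == 2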
@ -5,6 +5,7 @@ from typing import Any, Protocol

import click
from celery import current_app, shared_task
from sqlalchemy import select

from configs import dify_config
from core.db.session_factory import session_factory
@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):

Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()

with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
return
@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
except Exception as e:
for document_id in document_ids:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = IndexingStatus.ERROR
@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):

# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
documents: list[Document] = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)

for document in documents:
@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.warning("Dataset %s not found after indexing", dataset_id)
return
@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed

documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)

for document in documents:

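The multi-row queries in this task move from `session.query(...).all()` to `session.scalars(select(...)).all()`; `scalars()` unwraps the single-entity rows and wrapping the result in `list(...)` pins down the `list[Document]` annotation. A self-contained sketch with a stand-in `Doc` model (the real Document model has many more columns):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):  # toy stand-in for the Document model above
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(dataset_id="d1"), Doc(dataset_id="d1")])
    session.commit()

    # scalars() yields ORM objects rather than Row tuples; .all()
    # materializes them, and list() gives a concrete list type.
    docs: list[Doc] = list(
        session.scalars(select(Doc).where(Doc.dataset_id == "d1")).all()
    )
    assert len(docs) == 2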
@ -6,7 +6,7 @@ from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):

def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
session.execute(
delete(AppModelConfig)
.where(AppModelConfig.id == model_config_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):

def _delete_app_site(tenant_id: str, app_id: str):
def del_site(session, site_id: str):
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False))

_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str):

def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(session, mcp_server_id: str):
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
session.execute(
delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False)
)

_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1))
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)

session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
session.execute(
delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False)
)

_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):

def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(session, installed_app_id: str):
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
session.execute(
delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False)
)

_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str):

def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(session, recommended_app_id: str):
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
session.execute(
delete(RecommendedApp)
.where(RecommendedApp.id == recommended_app_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):

def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(session, annotation_hit_history_id: str):
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.id == annotation_hit_history_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)

def del_annotation_setting(session, annotation_setting_id: str):
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationSetting)
.where(AppAnnotationSetting.id == annotation_setting_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):

def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(session, dataset_join_id: str):
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
session.execute(
delete(AppDatasetJoin)
.where(AppDatasetJoin.id == dataset_join_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):

def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(session, workflow_id: str):
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False))

_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):

def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(session, workflow_app_log_id: str):
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id == workflow_app_log_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):

def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowArchiveLog)
.where(WorkflowArchiveLog.id == workflow_archive_log_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):

def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(session, conversation_id: str):
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
session.execute(
delete(PinnedConversation)
.where(PinnedConversation.conversation_id == conversation_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False)
)
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)

_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str):

def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(session, message_id: str):
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageFeedback)
.where(MessageFeedback.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageAnnotation)
.where(MessageAnnotation.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
session.query(Message).where(Message.id == message_id).delete()
session.execute(
delete(MessageChain)
.where(MessageChain.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageAgentThought)
.where(MessageAgentThought.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False)
)
session.execute(
delete(SavedMessage)
.where(SavedMessage.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False))

_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):

def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(session, tool_provider_id: str):
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowToolProvider)
.where(WorkflowToolProvider.id == tool_provider_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):

def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(session, tag_binding_id: str):
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
session.execute(
delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False)
)

_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):

def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(session, end_user_id: str):
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False))

_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str):

def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(session, trace_app_config_id: str):
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
session.execute(
delete(TraceAppConfig)
.where(TraceAppConfig.id == trace_app_config_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:

def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(session, trigger_id: str):
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
session.execute(
delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False)
)

_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str):

def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(session, trigger_id: str):
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowPluginTrigger)
.where(WorkflowPluginTrigger.id == trigger_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):

def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(session, trigger_id: str):
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowWebhookTrigger)
.where(WorkflowWebhookTrigger.id == trigger_id)
.execution_options(synchronize_session=False)
)

_delete_records(
@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):

def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(session, plan_id: str):
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowSchedulePlan)
.where(WorkflowSchedulePlan.id == plan_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):

def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(session, log_id: str):
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowTriggerLog)
.where(WorkflowTriggerLog.id == log_id)
.execution_options(synchronize_session=False)
)

_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",

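All of the `del_*` helpers above now emit a single Core `delete()` statement with `synchronize_session=False`, instead of the legacy `Query.delete()`. Skipping session synchronization is safe here because the helpers never touch the deleted objects again inside the same session. A minimal sketch of that pattern with a hypothetical `Token` model standing in for ApiToken and friends:

from sqlalchemy import create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Token(Base):  # toy stand-in for models such as ApiToken or EndUser
    __tablename__ = "tokens"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Token(app_id="a1"))
    session.commit()

    # Core delete() issues one DELETE statement; synchronize_session=False
    # skips the in-session bookkeeping pass over loaded objects.
    result = session.execute(
        delete(Token).where(Token.app_id == "a1").execution_options(synchronize_session=False)
    )
    session.commit()
    assert result.rowcount == 1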
@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers:
self.session.refresh(self.test_workflow_run)
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING

pause_states = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
.all()
)
pause_states = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
).all()
assert len(pause_states) == 0

def test_layer_requires_initialization(self, db_session_with_containers):

@ -0,0 +1,149 @@
"""
Integration tests for Conversation.inputs and Message.inputs tenant resolution.

Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching
with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB.
"""

from collections.abc import Generator
from unittest.mock import patch
from uuid import uuid4

import pytest
from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
from sqlalchemy.orm import Session

from core.workflow.file_reference import build_file_reference
from models.model import App, AppMode, Conversation, Message


def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict:
mapping: dict = {
"dify_model_identity": FILE_MODEL_IDENTITY,
"transfer_method": FileTransferMethod.LOCAL_FILE,
"reference": build_file_reference(record_id=record_id),
"type": "document",
"filename": "example.txt",
"extension": ".txt",
"mime_type": "text/plain",
"size": 1,
}
if tenant_id is not None:
mapping["tenant_id"] = tenant_id
return mapping


class TestConversationMessageInputsTenantResolution:
"""Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup."""

@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()

def _create_app(self, db_session: Session) -> App:
tenant_id = str(uuid4())
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.CHAT,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=str(uuid4()),
updated_by=str(uuid4()),
)
db_session.add(app)
db_session.flush()
return app

@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolves_tenant_via_db_for_local_file(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs resolves tenant_id from real App row when file mapping has no tenant_id."""
app = self._create_app(db_session_with_containers)
build_calls: list[tuple[dict, str]] = []

def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {"file": _build_local_file_mapping("upload-1")}

restored_inputs = owner.inputs

# The tenant_id should come from the real App row in the DB
assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}
assert len(build_calls) == 1
assert build_calls[0][1] == app.tenant_id

@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_uses_serialized_tenant_id_skipping_db_lookup(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs uses tenant_id from the file mapping payload without hitting the DB."""
app = self._create_app(db_session_with_containers)
payload_tenant_id = "tenant-from-payload"
build_calls: list[tuple[dict, str]] = []

def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)}

restored_inputs = owner.inputs

assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"}
assert len(build_calls) == 1
assert build_calls[0][1] == payload_tenant_id

@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolves_tenant_for_file_list(
self,
db_session_with_containers: Session,
owner_cls: type,
) -> None:
"""Inputs resolves tenant_id for a list of file mappings."""
app = self._create_app(db_session_with_containers)
build_calls: list[tuple[dict, str]] = []

def fake_build_from_mapping(
*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller
):
build_calls.append((dict(mapping), tenant_id))
return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping):
owner = owner_cls(app_id=app.id)
owner.inputs = {
"files": [
_build_local_file_mapping("upload-1"),
_build_local_file_mapping("upload-2"),
]
}

restored_inputs = owner.inputs

assert len(build_calls) == 2
assert all(call[1] == app.tenant_id for call in build_calls)
assert restored_inputs["files"] == [
{"tenant_id": app.tenant_id, "upload_file_id": "upload-1"},
{"tenant_id": app.tenant_id, "upload_file_id": "upload-2"},
]
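The tests above swap out factories.file_factory.build_from_mapping with `unittest.mock.patch`, which replaces the named attribute only for the duration of the `with` block and restores the original afterwards. A tiny self-contained illustration of the same mechanism against `math.sqrt`:

import math
from unittest.mock import patch

calls: list[float] = []

def fake_sqrt(x: float) -> float:
    # record the call, exactly like the build_calls list in the tests above
    calls.append(x)
    return 3.0

# Inside the block, math.sqrt is the fake; on exit the original is restored.
with patch("math.sqrt", fake_sqrt):
    assert math.sqrt(9.0) == 3.0

assert calls == [9.0]
assert math.sqrt(16.0) == 4.0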
@ -0,0 +1,314 @@
"""
Integration tests for Conversation.status_count and Site.generate_code model properties.

Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and
test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries.
"""

from collections.abc import Generator
from uuid import uuid4

import pytest
from graphon.enums import WorkflowExecutionStatus
from sqlalchemy.orm import Session

from models.enums import ConversationFromSource, InvokeFrom
from models.model import App, AppMode, Conversation, Message, Site
from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType


class TestConversationStatusCount:
"""Integration tests for Conversation.status_count property."""

@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()

def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App:
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.ADVANCED_CHAT,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=created_by,
updated_by=created_by,
)
db_session.add(app)
db_session.flush()
return app

def _create_conversation(self, db_session: Session, app: App) -> Conversation:
conversation = Conversation(
app_id=app.id,
mode=app.mode,
name=f"Conversation {uuid4()}",
summary="",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.WEB_APP,
from_source=ConversationFromSource.API,
dialogue_count=0,
is_deleted=False,
)
conversation.inputs = {}
db_session.add(conversation)
db_session.flush()
return conversation

def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow:
workflow = Workflow(
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.CHAT,
version="draft",
graph="{}",
created_by=created_by,
)
workflow._features = "{}"
db_session.add(workflow)
db_session.flush()
return workflow

def _create_workflow_run(
self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str
) -> WorkflowRun:
run = WorkflowRun(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type=WorkflowType.CHAT,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="draft",
status=status,
created_by_role="account",
created_by=created_by,
)
db_session.add(run)
db_session.flush()
return run

def _create_message(
self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
_inputs={},
query="Test query",
message={"role": "user", "content": "Test query"},
answer="Test answer",
model_provider="openai",
model_id="gpt-4",
message_tokens=10,
message_unit_price=0,
answer_tokens=10,
answer_unit_price=0,
total_price=0,
currency="USD",
from_source=ConversationFromSource.API,
invoke_from=InvokeFrom.WEB_APP,
workflow_run_id=workflow_run_id,
)
db_session.add(message)
db_session.flush()
return message

def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None:
"""status_count returns None when conversation has no messages with workflow_run_id."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)

result = conversation.status_count

assert result is None

def test_status_count_returns_none_when_messages_have_no_workflow_run_id(
self, db_session_with_containers: Session
) -> None:
"""status_count returns None when messages exist but none have workflow_run_id."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None)

result = conversation.status_count

assert result is None

def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts succeeded workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)

result = conversation.status_count

assert result is not None
assert result["success"] == 1
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0

def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts failed workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)

result = conversation.status_count

assert result is not None
assert result["success"] == 0
assert result["failed"] == 1
assert result["partial_success"] == 0
assert result["paused"] == 0

def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None:
"""status_count correctly counts paused workflow runs."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)
run = self._create_workflow_run(
db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by
)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)

result = conversation.status_count

assert result is not None
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 1

def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None:
"""status_count counts multiple workflow runs with different statuses."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, app, created_by)

for status in [
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.PAUSED,
]:
run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by)
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id)

result = conversation.status_count

assert result is not None
assert result["success"] == 1
assert result["failed"] == 1
assert result["partial_success"] == 1
assert result["paused"] == 1

def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None:
"""status_count excludes workflow runs belonging to a different app."""
tenant_id = str(uuid4())
created_by = str(uuid4())

app = self._create_app(db_session_with_containers, tenant_id, created_by)
other_app = self._create_app(db_session_with_containers, tenant_id, created_by)
conversation = self._create_conversation(db_session_with_containers, app)
workflow = self._create_workflow(db_session_with_containers, other_app, created_by)

# Workflow run belongs to other_app, not app
other_run = self._create_workflow_run(
db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by
)
# Message references that run but is in a conversation under app
self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id)

result = conversation.status_count

# The run should be excluded because app_id filter doesn't match
assert result is not None
assert result["success"] == 0


class TestSiteGenerateCode:
"""Integration tests for Site.generate_code static method."""

@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()

def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None:
"""Site.generate_code returns a code string of the requested length."""
code = Site.generate_code(8)

assert isinstance(code, str)
assert len(code) == 8

def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None:
"""Site.generate_code returns a code not already in use."""
tenant_id = str(uuid4())
app = App(
tenant_id=tenant_id,
name="Test App",
mode=AppMode.CHAT,
enable_site=True,
enable_api=False,
is_demo=False,
is_public=False,
is_universal=False,
created_by=str(uuid4()),
updated_by=str(uuid4()),
)
db_session_with_containers.add(app)
db_session_with_containers.flush()

site = Site(
app_id=app.id,
title="Test Site",
default_language="en-US",
customize_token_strategy="not_allow",
)
# Set an explicit code so generate_code must avoid it
site.code = "AAAAAAAA"
db_session_with_containers.add(site)
db_session_with_containers.flush()

code = Site.generate_code(8)

assert isinstance(code, str)
assert len(code) == 8
assert code != site.code
@ -6,7 +6,7 @@ import pytest
import sqlalchemy as sa
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
from sqlalchemy import insert, select
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql.sqltypes import VARCHAR
@ -137,12 +137,12 @@ class TestEnumText:
session.commit()

with Session(engine_with_containers) as session:
user = session.query(_User).where(_User.id == admin_user_id).first()
user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1))
assert user.user_type == _UserType.admin
assert user.user_type_nullable is None

with Session(engine_with_containers) as session:
user = session.query(_User).where(_User.id == normal_user_id).first()
user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1))
assert user.user_type == _UserType.normal
assert user.user_type_nullable == _UserType.normal

@ -206,7 +206,7 @@ class TestEnumText:

with pytest.raises(ValueError) as exc:
with Session(engine_with_containers) as session:
_user = session.query(_User).where(_User.id == 1).first()
_user = session.scalar(select(_User).where(_User.id == 1).limit(1))

assert str(exc.value) == "'invalid' is not a valid _UserType"

@ -222,7 +222,7 @@ class TestEnumText:
session.commit()

with Session(engine_with_containers) as session:
records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all()
records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all()

assert [record.model_type for record in records] == [
ModelType.LLM,

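The invalid-enum test above relies on `pytest.raises` capturing the exception object so the message can be asserted after the block exits. A self-contained sketch of that idiom with a hypothetical `parse_user_type` validator (not the real EnumText column logic):

import pytest

def parse_user_type(value: str) -> str:
    # hypothetical validator standing in for the enum coercion under test
    if value not in {"admin", "normal"}:
        raise ValueError(f"{value!r} is not a valid user type")
    return value

def test_invalid_value_raises() -> None:
    with pytest.raises(ValueError) as exc:
        parse_user_type("invalid")
    # exc.value is the raised exception instance, available after the block
    assert str(exc.value) == "'invalid' is not a valid user type"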
@ -0,0 +1,170 @@
"""
Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user.

Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing
monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows
persisted in PostgreSQL so the db.session.get() call executes against the DB.
"""

from collections.abc import Generator
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session

from models.account import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom


class TestWorkflowNodeExecutionModelCreatedBy:
"""Integration tests for WorkflowNodeExecutionModel creator lookup properties."""

@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()

def _create_account(self, db_session: Session) -> Account:
account = Account(
name="Test Account",
email=f"test_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session.add(account)
db_session.flush()
return account

def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser:
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="service_api",
external_user_id=f"ext-{uuid4()}",
name="End User",
session_id=f"session-{uuid4()}",
)
end_user.is_anonymous = False
db_session.add(end_user)
db_session.flush()
return end_user

def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App:
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
mode=AppMode.WORKFLOW,
enable_site=False,
enable_api=True,
is_demo=False,
is_public=False,
is_universal=False,
created_by=created_by,
updated_by=created_by,
)
db_session.add(app)
db_session.flush()
return app

def _make_execution(
self, tenant_id: str, app_id: str, created_by_role: str, created_by: str
) -> WorkflowNodeExecutionModel:
return WorkflowNodeExecutionModel(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=None,
index=1,
predecessor_node_id=None,
node_execution_id=None,
node_id="n1",
node_type="start",
title="Start",
inputs=None,
process_data=None,
outputs=None,
status="succeeded",
error=None,
elapsed_time=0.0,
execution_metadata=None,
created_by_role=created_by_role,
created_by=created_by,
)

def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None:
"""created_by_account returns the Account row when role is ACCOUNT."""
account = self._create_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, str(uuid4()), account.id)

execution = self._make_execution(
tenant_id=app.tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by=account.id,
)

result = execution.created_by_account

assert result is not None
assert result.id == account.id

def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None:
"""created_by_account returns None when role is END_USER, even if an Account exists."""
account = self._create_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, str(uuid4()), account.id)

execution = self._make_execution(
tenant_id=app.tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.END_USER.value,
created_by=account.id,
)

result = execution.created_by_account

assert result is None

def test_created_by_end_user_returns_end_user_when_role_is_end_user(
self, db_session_with_containers: Session
) -> None:
"""created_by_end_user returns the EndUser row when role is END_USER."""
account = self._create_account(db_session_with_containers)
tenant_id = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, account.id)
end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id)

execution = self._make_execution(
tenant_id=tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.END_USER.value,
created_by=end_user.id,
)

result = execution.created_by_end_user

assert result is not None
assert result.id == end_user.id

def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None:
"""created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists."""
account = self._create_account(db_session_with_containers)
tenant_id = str(uuid4())
app = self._create_app(db_session_with_containers, tenant_id, account.id)
end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id)

execution = self._make_execution(
tenant_id=tenant_id,
app_id=app.id,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by=end_user.id,
)

result = execution.created_by_end_user

assert result is None
@ -0,0 +1,395 @@
"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository."""

from __future__ import annotations

import json
from datetime import datetime
from decimal import Decimal
from uuid import uuid4

from graphon.entities import WorkflowNodeExecution
from graphon.enums import (
    BuiltinNodeTypes,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig
from models.account import Account, Tenant
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom


def _create_account_with_tenant(session: Session) -> Account:
    tenant = Tenant(name="Test Workspace")
    session.add(tenant)
    session.flush()

    account = Account(name="test", email=f"test-{uuid4()}@example.com")
    session.add(account)
    session.flush()

    account._current_tenant = tenant
    return account


def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository:
    engine = session.get_bind()
    assert isinstance(engine, Engine)
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sessionmaker(bind=engine, expire_on_commit=False),
        user=account,
        app_id=app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )


def _create_node_execution_model(
    session: Session,
    *,
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    workflow_run_id: str,
    index: int = 1,
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING,
) -> WorkflowNodeExecutionModel:
    model = WorkflowNodeExecutionModel(
        id=str(uuid4()),
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        workflow_run_id=workflow_run_id,
        index=index,
        predecessor_node_id=None,
        node_execution_id=str(uuid4()),
        node_id=f"node-{index}",
        node_type=BuiltinNodeTypes.START,
        title=f"Test Node {index}",
        inputs='{"input_key": "input_value"}',
        process_data='{"process_key": "process_value"}',
        outputs='{"output_key": "output_value"}',
        status=status,
        error=None,
        elapsed_time=1.5,
        execution_metadata="{}",
        created_at=datetime.now(),
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=str(uuid4()),
        finished_at=None,
    )
    session.add(model)
    session.flush()
    return model


class TestSave:
    def test_save_new_record(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        repo = _make_repo(db_session_with_containers, account, app_id)

        execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_execution_id=str(uuid4()),
            index=1,
            predecessor_node_id=None,
            node_id="node-1",
            node_type=BuiltinNodeTypes.START,
            title="Test Node",
            inputs={"input_key": "input_value"},
            process_data={"process_key": "process_value"},
            outputs={"result": "success"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            error=None,
            elapsed_time=1.5,
            metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100},
            created_at=datetime.now(),
            finished_at=None,
        )

        repo.save(execution)

        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
            saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
            assert saved is not None
            assert saved.tenant_id == account.current_tenant_id
            assert saved.app_id == app_id
            assert saved.node_id == "node-1"
            assert saved.status == WorkflowNodeExecutionStatus.RUNNING

    def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        repo = _make_repo(db_session_with_containers, account, str(uuid4()))

        execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_execution_id=str(uuid4()),
            index=1,
            predecessor_node_id=None,
            node_id="node-1",
            node_type=BuiltinNodeTypes.START,
            title="Test Node",
            inputs=None,
            process_data=None,
            outputs=None,
            status=WorkflowNodeExecutionStatus.RUNNING,
            error=None,
            elapsed_time=0.0,
            metadata=None,
            created_at=datetime.now(),
            finished_at=None,
        )

        repo.save(execution)

        execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        execution.elapsed_time = 2.5
        repo.save(execution)

        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
            saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
            assert saved is not None
            assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert saved.elapsed_time == 2.5


class TestGetByWorkflowExecution:
    def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        tenant_id = account.current_tenant_id
        app_id = str(uuid4())
        workflow_id = str(uuid4())
        workflow_run_id = str(uuid4())
        repo = _make_repo(db_session_with_containers, account, app_id)

        _create_node_execution_model(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            index=1,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
        )
        _create_node_execution_model(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            index=2,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
        )
        db_session_with_containers.commit()

        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        result = repo.get_by_workflow_execution(
            workflow_execution_id=workflow_run_id,
            order_config=order_config,
        )

        assert len(result) == 2
        assert result[0].index == 2
        assert result[1].index == 1
        assert all(isinstance(r, WorkflowNodeExecution) for r in result)

    def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        tenant_id = account.current_tenant_id
        app_id = str(uuid4())
        workflow_id = str(uuid4())
        workflow_run_id = str(uuid4())
        repo = _make_repo(db_session_with_containers, account, app_id)

        _create_node_execution_model(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            index=1,
            status=WorkflowNodeExecutionStatus.RUNNING,
        )
        _create_node_execution_model(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            index=2,
            status=WorkflowNodeExecutionStatus.PAUSED,
        )
        db_session_with_containers.commit()

        result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id)

        assert len(result) == 1
        assert result[0].index == 1


class TestToDbModel:
    def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        repo = _make_repo(db_session_with_containers, account, app_id)

        domain_model = WorkflowNodeExecution(
            id="test-id",
            workflow_id="test-workflow-id",
            node_execution_id="test-node-execution-id",
            workflow_execution_id="test-workflow-run-id",
            index=1,
            predecessor_node_id="test-predecessor-id",
            node_id="test-node-id",
            node_type=BuiltinNodeTypes.START,
            title="Test Node",
            inputs={"input_key": "input_value"},
            process_data={"process_key": "process_value"},
            outputs={"output_key": "output_value"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            error=None,
            elapsed_time=1.5,
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
            },
            created_at=datetime.now(),
            finished_at=None,
        )

        db_model = repo._to_db_model(domain_model)

        assert isinstance(db_model, WorkflowNodeExecutionModel)
        assert db_model.id == domain_model.id
        assert db_model.tenant_id == account.current_tenant_id
        assert db_model.app_id == app_id
        assert db_model.workflow_id == domain_model.workflow_id
        assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert db_model.workflow_run_id == domain_model.workflow_execution_id
        assert db_model.index == domain_model.index
        assert db_model.predecessor_node_id == domain_model.predecessor_node_id
        assert db_model.node_execution_id == domain_model.node_execution_id
        assert db_model.node_id == domain_model.node_id
        assert db_model.node_type == domain_model.node_type
        assert db_model.title == domain_model.title
        assert db_model.inputs_dict == domain_model.inputs
        assert db_model.process_data_dict == domain_model.process_data
        assert db_model.outputs_dict == domain_model.outputs
        assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)
        assert db_model.status == domain_model.status
        assert db_model.error == domain_model.error
        assert db_model.elapsed_time == domain_model.elapsed_time
        assert db_model.created_at == domain_model.created_at
        assert db_model.created_by_role == CreatorUserRole.ACCOUNT
        assert db_model.created_by == account.id
        assert db_model.finished_at == domain_model.finished_at


class TestToDomainModel:
    def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        repo = _make_repo(db_session_with_containers, account, app_id)

        inputs_dict = {"input_key": "input_value"}
        process_data_dict = {"process_key": "process_value"}
        outputs_dict = {"output_key": "output_value"}
        metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
        now = datetime.now()

        db_model = WorkflowNodeExecutionModel()
        db_model.id = "test-id"
        db_model.tenant_id = account.current_tenant_id
        db_model.app_id = app_id
        db_model.workflow_id = "test-workflow-id"
        db_model.triggered_from = "workflow-run"
        db_model.workflow_run_id = "test-workflow-run-id"
        db_model.index = 1
        db_model.predecessor_node_id = "test-predecessor-id"
        db_model.node_execution_id = "test-node-execution-id"
        db_model.node_id = "test-node-id"
        db_model.node_type = BuiltinNodeTypes.START
        db_model.title = "Test Node"
        db_model.inputs = json.dumps(inputs_dict)
        db_model.process_data = json.dumps(process_data_dict)
        db_model.outputs = json.dumps(outputs_dict)
        db_model.status = WorkflowNodeExecutionStatus.RUNNING
        db_model.error = None
        db_model.elapsed_time = 1.5
        db_model.execution_metadata = json.dumps(metadata_dict)
        db_model.created_at = now
        db_model.created_by_role = "account"
        db_model.created_by = account.id
        db_model.finished_at = None

        domain_model = repo._to_domain_model(db_model)

        assert isinstance(domain_model, WorkflowNodeExecution)
        assert domain_model.id == "test-id"
        assert domain_model.workflow_id == "test-workflow-id"
        assert domain_model.workflow_execution_id == "test-workflow-run-id"
        assert domain_model.index == 1
        assert domain_model.predecessor_node_id == "test-predecessor-id"
        assert domain_model.node_execution_id == "test-node-execution-id"
        assert domain_model.node_id == "test-node-id"
        assert domain_model.node_type == BuiltinNodeTypes.START
        assert domain_model.title == "Test Node"
        assert domain_model.inputs == inputs_dict
        assert domain_model.process_data == process_data_dict
        assert domain_model.outputs == outputs_dict
        assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING
        assert domain_model.error is None
        assert domain_model.elapsed_time == 1.5
        assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()}
        assert domain_model.created_at == now
        assert domain_model.finished_at is None

    def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None:
        account = _create_account_with_tenant(db_session_with_containers)
        repo = _make_repo(db_session_with_containers, account, str(uuid4()))

        process_data = {"normal": "data"}
        db_model = WorkflowNodeExecutionModel()
        db_model.id = str(uuid4())
        db_model.tenant_id = account.current_tenant_id
        db_model.app_id = str(uuid4())
        db_model.workflow_id = str(uuid4())
        db_model.triggered_from = "workflow-run"
        db_model.workflow_run_id = None
        db_model.index = 1
        db_model.predecessor_node_id = None
        db_model.node_execution_id = str(uuid4())
        db_model.node_id = "test-node-id"
        db_model.node_type = "llm"
        db_model.title = "Test Node"
        db_model.inputs = None
        db_model.process_data = json.dumps(process_data)
        db_model.outputs = None
        db_model.status = "succeeded"
        db_model.error = None
        db_model.elapsed_time = 1.5
        db_model.execution_metadata = "{}"
        db_model.created_at = datetime.now()
        db_model.created_by_role = "account"
        db_model.created_by = account.id
        db_model.finished_at = None

        domain_model = repo._to_domain_model(db_model)

        assert domain_model.process_data == process_data
        assert domain_model.process_data_truncated is False
        assert domain_model.get_truncated_process_data() is None

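A note on the verification pattern used throughout the file above: assertions re-read rows through a second `sessionmaker(bind=engine, expire_on_commit=False)` session rather than the session that performed the write, so the test observes committed state instead of objects cached in the writer's identity map. A minimal sketch of that read-back step, with `engine`, `model_cls`, and `pk` as illustrative stand-ins:

    from sqlalchemy.orm import sessionmaker

    def read_back(engine, model_cls, pk):
        # A fresh session shares no state with the session that did the write,
        # so Session.get() here reflects what was actually committed.
        verify_factory = sessionmaker(bind=engine, expire_on_commit=False)
        with verify_factory() as verify_session:
            return verify_session.get(model_cls, pk)
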
@ -0,0 +1,255 @@
"""
Integration tests for RagPipelineService methods that interact with the database.

Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing
db.session.scalar/commit/delete mocker patches with real PostgreSQL operations.

Covers:
- get_pipeline: Dataset and Pipeline lookups
- update_customized_pipeline_template: find + unique-name check + commit
- delete_customized_pipeline_template: find + delete + commit
"""

from collections.abc import Generator
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session, sessionmaker

from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.enums import DataSourceType
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService


class TestRagPipelineServiceGetPipeline:
    """Integration tests for RagPipelineService.get_pipeline."""

    @pytest.fixture(autouse=True)
    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
        yield
        db_session_with_containers.rollback()

    def _make_service(self, flask_app_with_containers) -> RagPipelineService:
        with (
            patch(
                "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
                return_value=None,
            ),
            patch(
                "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
                return_value=None,
            ),
        ):
            session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine)
            return RagPipelineService(session_maker=session_factory)

    def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline:
        pipeline = Pipeline(
            tenant_id=tenant_id,
            name=f"Pipeline {uuid4()}",
            description="",
            created_by=created_by,
        )
        db_session.add(pipeline)
        db_session.flush()
        return pipeline

    def _create_dataset(
        self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None
    ) -> Dataset:
        dataset = Dataset(
            tenant_id=tenant_id,
            name=f"Dataset {uuid4()}",
            data_source_type=DataSourceType.UPLOAD_FILE,
            created_by=created_by,
            pipeline_id=pipeline_id,
        )
        db_session.add(dataset)
        db_session.flush()
        return dataset

    def test_get_pipeline_raises_when_dataset_not_found(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """get_pipeline raises ValueError when the dataset does not exist."""
        service = self._make_service(flask_app_with_containers)

        with pytest.raises(ValueError, match="Dataset not found"):
            service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4()))

    def test_get_pipeline_raises_when_pipeline_not_found(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """get_pipeline raises ValueError when the dataset exists but has no linked pipeline."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None)
        db_session_with_containers.flush()

        service = self._make_service(flask_app_with_containers)

        with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"):
            service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)

    def test_get_pipeline_returns_pipeline_when_found(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """get_pipeline returns the Pipeline when both Dataset and Pipeline exist."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())

        pipeline = self._create_pipeline(db_session_with_containers, tenant_id, created_by)
        dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id)
        db_session_with_containers.flush()

        service = self._make_service(flask_app_with_containers)

        result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)

        assert result.id == pipeline.id


class TestUpdateCustomizedPipelineTemplate:
    """Integration tests for RagPipelineService.update_customized_pipeline_template."""

    @pytest.fixture(autouse=True)
    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
        yield
        db_session_with_containers.rollback()

    def _create_template(
        self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template"
    ) -> PipelineCustomizedTemplate:
        template = PipelineCustomizedTemplate(
            tenant_id=tenant_id,
            name=name,
            description="Original description",
            chunk_structure="fixed_size",
            icon={"type": "emoji", "value": "📄"},
            position=1,
            yaml_content="{}",
            install_count=0,
            language="en-US",
            created_by=created_by,
        )
        db_session.add(template)
        db_session.flush()
        return template

    def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
        """update_customized_pipeline_template updates name and description."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        template = self._create_template(db_session_with_containers, tenant_id, created_by)
        db_session_with_containers.flush()

        fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)

        with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
            info = PipelineTemplateInfoEntity(
                name="Updated Name",
                description="Updated description",
                icon_info=IconInfo(icon="🔥"),
            )
            result = RagPipelineService.update_customized_pipeline_template(template.id, info)

        assert result.name == "Updated Name"
        assert result.description == "Updated description"

    def test_update_template_raises_when_not_found(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """update_customized_pipeline_template raises ValueError when the template doesn't exist."""
        fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))

        with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
            info = PipelineTemplateInfoEntity(
                name="New Name",
                description="desc",
                icon_info=IconInfo(icon="📄"),
            )
            with pytest.raises(ValueError, match="Customized pipeline template not found"):
                RagPipelineService.update_customized_pipeline_template(str(uuid4()), info)

    def test_update_template_raises_on_duplicate_name(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """update_customized_pipeline_template raises ValueError when the new name already exists."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original")
        self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate")
        db_session_with_containers.flush()

        fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)

        with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
            info = PipelineTemplateInfoEntity(
                name="Duplicate",
                description="desc",
                icon_info=IconInfo(icon="📄"),
            )
            with pytest.raises(ValueError, match="Template name is already exists"):
                RagPipelineService.update_customized_pipeline_template(template1.id, info)


class TestDeleteCustomizedPipelineTemplate:
    """Integration tests for RagPipelineService.delete_customized_pipeline_template."""

    @pytest.fixture(autouse=True)
    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
        yield
        db_session_with_containers.rollback()

    def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate:
        template = PipelineCustomizedTemplate(
            tenant_id=tenant_id,
            name=f"Template {uuid4()}",
            description="Description",
            chunk_structure="fixed_size",
            icon={"type": "emoji", "value": "📄"},
            position=1,
            yaml_content="{}",
            install_count=0,
            language="en-US",
            created_by=created_by,
        )
        db_session.add(template)
        db_session.flush()
        return template

    def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
        """delete_customized_pipeline_template removes the template from the DB."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        template = self._create_template(db_session_with_containers, tenant_id, created_by)
        template_id = template.id
        db_session_with_containers.flush()

        fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)

        with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
            RagPipelineService.delete_customized_pipeline_template(template_id)

        # Verify the record is deleted within the same context
        from sqlalchemy import select

        from extensions.ext_database import db as ext_db

        remaining = ext_db.session.scalar(
            select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
        )
        assert remaining is None

    def test_delete_template_raises_when_not_found(
        self, db_session_with_containers: Session, flask_app_with_containers
    ) -> None:
        """delete_customized_pipeline_template raises ValueError when the template doesn't exist."""
        fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))

        with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
            with pytest.raises(ValueError, match="Customized pipeline template not found"):
                RagPipelineService.delete_customized_pipeline_template(str(uuid4()))

@ -26,7 +26,7 @@ from datetime import timedelta
import pytest
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus
from sqlalchemy import delete, select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, selectinload, sessionmaker

from extensions.ext_storage import storage

@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration:

        # Verify only 3 were deleted
        remaining_count = (
            self.session.query(WorkflowPauseModel)
            .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
            .count()
            self.session.scalar(
                select(func.count(WorkflowPauseModel.id)).where(
                    WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])
                )
            )
            or 0
        )
        assert remaining_count == 2


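The hunk above replaces the legacy `Query.count()` call with the SQLAlchemy 2.0 `select(func.count(...))` idiom read through `Session.scalar()`; the trailing `or 0` guards against a `None` scalar. A self-contained sketch of the same pattern (the `Item` table here is illustrative, not part of the Dify schema):

    from sqlalchemy import create_engine, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Item(Base):
        __tablename__ = "items"
        id: Mapped[int] = mapped_column(primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Item(), Item(), Item()])
        session.flush()
        # scalar() unwraps the single COUNT value; `or 0` covers a None result
        count = session.scalar(select(func.count(Item.id)).where(Item.id.in_([1, 2]))) or 0
        assert count == 2
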
@ -258,10 +258,10 @@ class TestParentChildIndexProcessor:
        session.commit.assert_called_once()

    def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
        segment_query = Mock()
        segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
        scalars_result = Mock()
        scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
        session = Mock()
        session.query.return_value = segment_query
        session.scalars.return_value = scalars_result
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = False

@ -220,10 +220,10 @@ class TestQAIndexProcessor:
        self, processor: QAIndexProcessor, dataset: Mock
    ) -> None:
        mock_segment = SimpleNamespace(id="seg-1")
        mock_query = Mock()
        mock_query.filter.return_value.all.return_value = [mock_segment]
        scalars_result = Mock()
        scalars_result.all.return_value = [mock_segment]
        mock_session = Mock()
        mock_session.query.return_value = mock_query
        mock_session.scalars.return_value = scalars_result
        session_context = MagicMock()
        session_context.__enter__.return_value = mock_session
        session_context.__exit__.return_value = False

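Both index-processor hunks above track the same production migration: code under test now reads segments via `session.scalars(select(...)).all()` rather than `session.query(...).filter(...).all()`, so the test doubles must stub `scalars()` instead of `query()`. A minimal sketch of the stubbing pattern, with a hypothetical `load_segment_ids` standing in for the processor code:

    from unittest.mock import Mock

    def load_segment_ids(session) -> list[str]:
        # Stand-in for production code using the SQLAlchemy 2.0 style read API
        return [segment.id for segment in session.scalars("select ...").all()]

    scalars_result = Mock()
    scalars_result.all.return_value = [Mock(id="seg-1")]
    session = Mock()
    session.scalars.return_value = scalars_result

    assert load_segment_ids(session) == ["seg-1"]
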
@ -8,7 +8,6 @@ import pytest
from flask import Flask, current_app
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelFeature
from sqlalchemy import column

from core.app.app_config.entities import (
    DatasetEntity,
@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers:

    def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None:
        session = Mock()
        subquery_query = Mock()
        subquery_query.where.return_value = subquery_query
        subquery_query.group_by.return_value = subquery_query
        subquery_query.having.return_value = subquery_query
        subquery_query.subquery.return_value = SimpleNamespace(
            c=SimpleNamespace(
                dataset_id=column("dataset_id"), available_document_count=column("available_document_count")
            )
        )

        dataset_query = Mock()
        dataset_query.outerjoin.return_value = dataset_query
        dataset_query.where.return_value = dataset_query
        dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
        session.query.side_effect = [subquery_query, dataset_query]
        scalars_result = Mock()
        scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
        session.scalars.return_value = scalars_result

        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage:
            _scalars(segments),
            _scalars(bindings),
        ]
        query = Mock()
        query.where.return_value = query
        session.query.return_value = query
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = False
@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage:
        ):
            retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})

        query.update.assert_called_once()
        session.execute.assert_called_once()
        mock_trace.assert_called_once()

    def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:

@ -260,6 +260,28 @@ def test_agent_invoke_engine_meta_error():
    assert error_meta.error == "meta failure"


def test_convert_tool_response_excludes_variable_messages():
    """Regression test for issue #34723.

    WorkflowTool._invoke yields VARIABLE, TEXT, and suppressed-JSON messages.
    _convert_tool_response_to_str must skip VARIABLE messages so that the
    returned string contains only the TEXT representation and not a
    duplicated, garbled Pydantic repr of the same data.
    """
    tool = _build_tool()
    outputs = {"reports": "hello"}
    messages = [
        tool.create_variable_message(variable_name="reports", variable_value="hello"),
        tool.create_text_message('{"reports": "hello"}'),
        tool.create_json_message(outputs, suppress_output=True),
    ]

    result = ToolEngine._convert_tool_response_to_str(messages)

    assert result == '{"reports": "hello"}'
    assert "variable_name" not in result


def test_agent_invoke_tool_invoke_error():
    tool = _build_tool(with_llm_parameter=True)
    callback = Mock()

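The regression test above pins down the contract described in its docstring: when flattening tool messages to a string, only TEXT-bearing messages should contribute, so a VARIABLE message and its suppressed JSON twin are not rendered a second time. A standalone sketch of that filtering idea, with the `Msg` types below as hypothetical stand-ins for Dify's ToolInvokeMessage:

    from dataclasses import dataclass
    from enum import Enum

    class MsgType(Enum):
        TEXT = "text"
        VARIABLE = "variable"
        JSON = "json"

    @dataclass
    class Msg:
        type: MsgType
        payload: str = ""

    def to_str(messages: list[Msg]) -> str:
        # Keep only TEXT messages; VARIABLE (and suppressed JSON) carry the
        # same data and would otherwise be duplicated in the flattened string.
        return "".join(m.payload for m in messages if m.type is MsgType.TEXT)

    msgs = [
        Msg(MsgType.VARIABLE, "reports"),
        Msg(MsgType.TEXT, '{"reports": "hello"}'),
        Msg(MsgType.JSON, '{"reports": "hello"}'),
    ]
    assert to_str(msgs) == '{"reports": "hello"}'
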
@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql():
    for scheme in ("postgresql", "mysql"):
        session = Mock()
        session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
        session.query.return_value.where.return_value.all.return_value = provider_records
        session.scalars.return_value = iter(provider_records)

        with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)):
            with patch("core.tools.tool_manager.db") as mock_db:

@ -291,24 +291,6 @@ class TestAppModelConfig:
        # Assert
        assert result == questions

    def test_app_model_config_annotation_reply_dict_disabled(self):
        """Test annotation_reply_dict when annotation is disabled."""
        # Arrange
        config = AppModelConfig(
            app_id=str(uuid4()),
            provider="openai",
            model_id="gpt-4",
            created_by=str(uuid4()),
        )

        # Mock database scalar to return None (no annotation setting found)
        with patch("models.model.db.session.scalar", return_value=None):
            # Act
            result = config.annotation_reply_dict

            # Assert
            assert result == {"enabled": False}


class TestConversationModel:
    """Test suite for Conversation model integrity."""
@ -948,17 +930,6 @@ class TestSiteModel:
        with pytest.raises(ValueError, match="Custom disclaimer cannot exceed 512 characters"):
            site.custom_disclaimer = long_disclaimer

    def test_site_generate_code(self):
        """Test Site.generate_code static method."""
        # Mock database scalar to return 0 (no existing codes)
        with patch("models.model.db.session.scalar", return_value=0):
            # Act
            code = Site.generate_code(8)

            # Assert
            assert isinstance(code, str)
            assert len(code) == 8


class TestModelIntegration:
    """Test suite for model integration scenarios."""
@ -1146,314 +1117,3 @@ class TestModelIntegration:
        # Assert
        assert site.app_id == app.id
        assert app.enable_site is True


class TestConversationStatusCount:
    """Test suite for Conversation.status_count property N+1 query fix."""

    def test_status_count_no_messages(self):
        """Test status_count returns None when conversation has no messages."""
        # Arrange
        conversation = Conversation(
            app_id=str(uuid4()),
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = str(uuid4())

        # Mock the database query to return no messages
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_messages_without_workflow_runs(self):
        """Test status_count when messages have no workflow_run_id."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = conversation_id

        # Mock the database query to return no messages with workflow_run_id
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_batch_loading_implementation(self):
        """Test that status_count uses batch loading instead of N+1 queries."""
        # Arrange
        from graphon.enums import WorkflowExecutionStatus

        app_id = str(uuid4())
        conversation_id = str(uuid4())

        # Create workflow run IDs
        workflow_run_id_1 = str(uuid4())
        workflow_run_id_2 = str(uuid4())
        workflow_run_id_3 = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = conversation_id

        # Mock messages with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_1,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_2,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_3,
            ),
        ]

        # Mock workflow runs with different statuses
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id_1,
                status=WorkflowExecutionStatus.SUCCEEDED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_2,
                status=WorkflowExecutionStatus.FAILED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_3,
                status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
                app_id=app_id,
            ),
        ]

        # Track database calls
        calls_made = []

        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()

            # Return messages for the first query (messages with workflow_run_id)
            if "messages" in str(query) and "conversation_id" in str(query):
                mock_result.all.return_value = mock_messages
            # Return workflow runs for the batch query
            elif "workflow_runs" in str(query):
                mock_result.all.return_value = mock_workflow_runs
            else:
                mock_result.all.return_value = []

            return mock_result

        # Act & Assert
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

            # Verify only 2 database queries were made (not N+1)
            assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"

            # Verify the first query gets messages
            assert "messages" in calls_made[0]
            assert "conversation_id" in calls_made[0]

            # Verify the second query batch loads workflow runs with proper filtering
            assert "workflow_runs" in calls_made[1]
            assert "app_id" in calls_made[1]  # Security filter applied
            assert "IN" in calls_made[1]  # Batch loading with IN clause

            # Verify correct status counts
            assert result["success"] == 1  # One SUCCEEDED
            assert result["failed"] == 1  # One FAILED
            assert result["partial_success"] == 1  # One PARTIAL_SUCCEEDED
            assert result["paused"] == 0

    def test_status_count_app_id_filtering(self):
        """Test that status_count filters workflow runs by app_id for security."""
        # Arrange
        app_id = str(uuid4())
        other_app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = conversation_id

        # Mock message with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        calls_made = []

        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()

            if "messages" in str(query):
                mock_result.all.return_value = mock_messages
            elif "workflow_runs" in str(query):
                # Return empty list because no workflow run matches the correct app_id
                mock_result.all.return_value = []  # Workflow run filtered out by app_id
            else:
                mock_result.all.return_value = []

            return mock_result

        # Act
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

        # Assert - query should include app_id filter
        workflow_query = calls_made[1]
        assert "app_id" in workflow_query

        # Since workflow run has wrong app_id, it shouldn't be included in counts
        assert result["success"] == 0
        assert result["failed"] == 0
        assert result["partial_success"] == 0
        assert result["paused"] == 0

    def test_status_count_handles_invalid_workflow_status(self):
        """Test that status_count gracefully handles invalid workflow status values."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = conversation_id

        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        # Mock workflow run with invalid status
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id,
                status="invalid_status",  # Invalid status that should raise ValueError
                app_id=app_id,
            ),
        ]

        with patch("models.model.db.session.scalars") as mock_scalars:
            # Mock the messages query
            def mock_scalars_side_effect(query):
                mock_result = MagicMock()
                if "messages" in str(query):
                    mock_result.all.return_value = mock_messages
                elif "workflow_runs" in str(query):
                    mock_result.all.return_value = mock_workflow_runs
                else:
                    mock_result.all.return_value = []
                return mock_result

            mock_scalars.side_effect = mock_scalars_side_effect

            # Act - should not raise exception
            result = conversation.status_count

            # Assert - should handle invalid status gracefully
            assert result["success"] == 0
            assert result["failed"] == 0
            assert result["partial_success"] == 0
            assert result["paused"] == 0

    def test_status_count_paused(self):
        """Test status_count includes paused workflow runs."""
        # Arrange
        from graphon.enums import WorkflowExecutionStatus

        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source=ConversationFromSource.API,
        )
        conversation.id = conversation_id

        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id,
                status=WorkflowExecutionStatus.PAUSED.value,
                app_id=app_id,
            ),
        ]

        with patch("models.model.db.session.scalars") as mock_scalars:

            def mock_scalars_side_effect(query):
                mock_result = MagicMock()
                if "messages" in str(query):
                    mock_result.all.return_value = mock_messages
                elif "workflow_runs" in str(query):
                    mock_result.all.return_value = mock_workflow_runs
                else:
                    mock_result.all.return_value = []
                return mock_result

            mock_scalars.side_effect = mock_scalars_side_effect

            # Act
            result = conversation.status_count

            # Assert
            assert result["paused"] == 1

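The TestConversationStatusCount suite above guards an N+1 fix: one query collects the conversation's workflow_run_ids, and a second query batch-loads all referenced runs with an IN clause scoped to the app_id. A schematic, runnable sketch of that two-query shape; the `Message` and `WorkflowRun` models and the status strings below are illustrative stand-ins, not the Dify schema:

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Message(Base):
        __tablename__ = "messages"
        id: Mapped[int] = mapped_column(primary_key=True)
        conversation_id: Mapped[str]
        workflow_run_id: Mapped[str | None]

    class WorkflowRun(Base):
        __tablename__ = "workflow_runs"
        id: Mapped[str] = mapped_column(primary_key=True)
        app_id: Mapped[str]
        status: Mapped[str]

    def status_count(session: Session, conversation_id: str, app_id: str) -> dict[str, int]:
        counts = {"success": 0, "failed": 0, "partial_success": 0, "paused": 0}
        # Query 1: every workflow_run_id referenced by this conversation's messages
        run_ids = session.scalars(
            select(Message.workflow_run_id).where(
                Message.conversation_id == conversation_id,
                Message.workflow_run_id.is_not(None),
            )
        ).all()
        if not run_ids:
            return counts
        # Query 2: batch-load all runs in one IN query, scoped by app_id for tenancy
        runs = session.scalars(
            select(WorkflowRun).where(WorkflowRun.id.in_(run_ids), WorkflowRun.app_id == app_id)
        ).all()
        key_for = {"succeeded": "success", "failed": "failed",
                   "partial-succeeded": "partial_success", "paused": "paused"}
        for run in runs:
            key = key_for.get(run.status)
            if key is not None:
                counts[key] += 1
        return counts
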
@ -101,118 +101,6 @@ def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -
    return mapping


@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolve_owner_tenant_for_single_file_mapping(
    monkeypatch: pytest.MonkeyPatch,
    owner_cls: type[Conversation] | type[Message],
):
    model_module = importlib.import_module("models.model")
    build_calls: list[tuple[dict[str, object], str]] = []

    monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")

    def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
        _ = config, strict_type_validation, access_controller
        build_calls.append((dict(mapping), tenant_id))
        return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

    monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)

    owner = owner_cls(app_id="app-1")
    owner.inputs = {"file": _build_local_file_mapping("upload-1")}

    restored_inputs = owner.inputs

    assert restored_inputs["file"] == {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"}
    assert build_calls == [
        (
            {
                **_build_local_file_mapping("upload-1"),
                "upload_file_id": "upload-1",
            },
            "tenant-from-app",
        )
    ]


@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_resolve_owner_tenant_for_file_list_mapping(
    monkeypatch: pytest.MonkeyPatch,
    owner_cls: type[Conversation] | type[Message],
):
    model_module = importlib.import_module("models.model")
    build_calls: list[tuple[dict[str, object], str]] = []

    monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app")

    def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
        _ = config, strict_type_validation, access_controller
        build_calls.append((dict(mapping), tenant_id))
        return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

    monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)

    owner = owner_cls(app_id="app-1")
    owner.inputs = {
        "files": [
            _build_local_file_mapping("upload-1"),
            _build_local_file_mapping("upload-2"),
        ]
    }

    restored_inputs = owner.inputs

    assert restored_inputs["files"] == [
        {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"},
        {"tenant_id": "tenant-from-app", "upload_file_id": "upload-2"},
    ]
    assert build_calls == [
        (
            {
                **_build_local_file_mapping("upload-1"),
                "upload_file_id": "upload-1",
            },
            "tenant-from-app",
        ),
        (
            {
                **_build_local_file_mapping("upload-2"),
                "upload_file_id": "upload-2",
            },
            "tenant-from-app",
        ),
    ]


@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_prefer_serialized_tenant_id_when_present(
    monkeypatch: pytest.MonkeyPatch,
    owner_cls: type[Conversation] | type[Message],
):
    model_module = importlib.import_module("models.model")

    def fail_if_called(_):
        raise AssertionError("App tenant lookup should not run when tenant_id exists in the file mapping")

    monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called)

    def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller):
        _ = config, strict_type_validation, access_controller
        return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")}

    monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping)

    owner = owner_cls(app_id="app-1")
    owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id="tenant-from-payload")}

    restored_inputs = owner.inputs

    assert restored_inputs["file"] == {
        "tenant_id": "tenant-from-payload",
        "upload_file_id": "upload-1",
    }


@pytest.mark.parametrize("owner_cls", [Conversation, Message])
def test_inputs_restore_external_remote_url_file_mappings(owner_cls: type[Conversation] | type[Message]) -> None:
    owner = owner_cls(app_id="app-1")

@ -1,188 +0,0 @@
import types

import pytest

from models.engine import db
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel


@pytest.fixture
def fake_db_scalar(monkeypatch):
    """Provide a controllable fake for db.session.scalar (SQLAlchemy 2.0 style)."""
    calls = []

    def _install(side_effect):
        def _fake_scalar(statement):
            calls.append(statement)
            return side_effect(statement)

        # Patch the modern API used by the model implementation
        monkeypatch.setattr(db.session, "scalar", _fake_scalar)

        # Backward-compatibility: if the implementation still uses db.session.get,
        # make it delegate to the same side_effect so tests remain valid on older code.
        if hasattr(db.session, "get"):

            def _fake_get(*_args, **_kwargs):
                return side_effect(None)

            monkeypatch.setattr(db.session, "get", _fake_get)

        return calls

    return _install


def make_account(id_: str = "acc-1"):
    # Use a simple object to avoid constructing a full SQLAlchemy model instance
    # Python 3.12 forbids reassigning __class__ for SimpleNamespace; not needed here.
    obj = types.SimpleNamespace()
    obj.id = id_
    return obj


def make_end_user(id_: str = "user-1"):
    # Lightweight stand-in object; no need to spoof class identity.
    obj = types.SimpleNamespace()
    obj.id = id_
    return obj


def test_created_by_account_returns_account_when_role_account(fake_db_scalar):
    account = make_account("acc-1")

    # The implementation uses db.session.scalar(select(Account)...). We only need to
    # return the expected object when called; the exact SQL is irrelevant for this unit test.
    def side_effect(_statement):
        return account

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by="acc-1",
    )

    assert log.created_by_account is account


def test_created_by_account_returns_none_when_role_not_account(fake_db_scalar):
    # Even if an Account with matching id exists, property should return None when role is END_USER
    account = make_account("acc-1")

    def side_effect(_statement):
        return account

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.END_USER.value,
        created_by="acc-1",
    )

    assert log.created_by_account is None


def test_created_by_end_user_returns_end_user_when_role_end_user(fake_db_scalar):
    end_user = make_end_user("user-1")

    def side_effect(_statement):
        return end_user

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.END_USER.value,
        created_by="user-1",
    )

    assert log.created_by_end_user is end_user


def test_created_by_end_user_returns_none_when_role_not_end_user(fake_db_scalar):
    end_user = make_end_user("user-1")

    def side_effect(_statement):
        return end_user

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by="user-1",
    )

    assert log.created_by_end_user is None
@ -1,3 +0,0 @@
"""
Unit tests for workflow_node_execution repositories.
"""
@ -1,340 +0,0 @@
|
||||
"""
|
||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
from graphon.entities import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def configure_mock_execution(mock_execution):
|
||||
"""Configure a mock execution with proper JSON serializable values."""
|
||||
# Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values
|
||||
type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}')
|
||||
type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}')
|
||||
type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}')
|
||||
type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}')
|
||||
|
||||
# Configure status and triggered_from to be valid enum values
|
||||
mock_execution.status = "running"
|
||||
mock_execution.triggered_from = "workflow-run"
|
||||
|
||||
return mock_execution


@pytest.fixture
def session():
    """Create a mock SQLAlchemy session."""
    session = MagicMock(spec=Session)
    # Configure the session to be used as a context manager
    session.__enter__ = MagicMock(return_value=session)
    session.__exit__ = MagicMock(return_value=None)

    # Configure the session factory to return the session
    session_factory = MagicMock(spec=sessionmaker)
    session_factory.return_value = session
    return session, session_factory


@pytest.fixture
def mock_user():
    """Create a user instance for testing."""
    user = Account(name="test", email="test@example.com")
    user.id = "test-user-id"

    tenant = Tenant(name="Test Workspace")
    tenant.id = "test-tenant"
    user._current_tenant = MagicMock()
    user._current_tenant.id = "test-tenant"

    return user


@pytest.fixture
def repository(session, mock_user):
    """Create a repository instance with test data."""
    _, session_factory = session
    app_id = "test-app"
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_factory,
        user=mock_user,
        app_id=app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )


def test_save(repository, session):
    """Test save method."""
    session_obj, _ = session
    # Create a mock execution
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.id = "test-id"
    execution.node_execution_id = "test-node-execution-id"
    execution.tenant_id = None
    execution.app_id = None
    execution.inputs = None
    execution.process_data = None
    execution.outputs = None
    execution.metadata = None
    execution.workflow_id = str(uuid.uuid4())

    # Mock the _to_db_model method to return a prepared DB model;
    # this simulates the conversion that fills in tenant_id and app_id
    db_model = MagicMock(spec=WorkflowNodeExecutionModel)
    db_model.id = "test-id"
    db_model.node_execution_id = "test-node-execution-id"
    repository._to_db_model = MagicMock(return_value=db_model)

    # Mock session.get to return None (no existing record)
    session_obj.get.return_value = None

    # Call save method
    repository.save(execution)

    # Assert _to_db_model was called with the execution
    repository._to_db_model.assert_called_once_with(execution)

    # Assert session.get was called to check for an existing record
    session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)

    # Assert session.add was called for the new record
    session_obj.add.assert_called_once_with(db_model)

    # Assert session.commit was called
    session_obj.commit.assert_called_once()
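

# Hedged sketch of the get-then-add upsert the assertions above encode; the real
# repository.save may differ in detail, and `_sketch_save` is a hypothetical name.
def _sketch_save(session_obj, db_model):
    existing = session_obj.get(WorkflowNodeExecutionModel, db_model.id)
    if existing is None:
        # No row with this primary key yet: INSERT via session.add
        session_obj.add(db_model)
    else:
        # Row exists: copy column values onto the attached instance (assumed)
        for key, value in db_model.__dict__.items():
            if not key.startswith("_"):
                setattr(existing, key, value)
    session_obj.commit()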


def test_save_with_existing_tenant_id(repository, session):
    """Test save method with existing tenant_id."""
    session_obj, _ = session
    # Create a mock execution with an existing tenant_id
    execution = MagicMock(spec=WorkflowNodeExecutionModel)
    execution.id = "existing-id"
    execution.node_execution_id = "existing-node-execution-id"
    execution.tenant_id = "existing-tenant"
    execution.app_id = None
    execution.inputs = None
    execution.process_data = None
    execution.outputs = None
    execution.metadata = None

    # Create a modified execution that will be returned by _to_db_model
    modified_execution = MagicMock(spec=WorkflowNodeExecutionModel)
    modified_execution.id = "existing-id"
    modified_execution.node_execution_id = "existing-node-execution-id"
    modified_execution.tenant_id = "existing-tenant"  # Tenant ID should not change
    modified_execution.app_id = repository._app_id  # App ID should be set
    # Create a dictionary to simulate __dict__ for updating attributes
    modified_execution.__dict__ = {
        "id": "existing-id",
        "node_execution_id": "existing-node-execution-id",
        "tenant_id": "existing-tenant",
        "app_id": repository._app_id,
    }

    # Mock the _to_db_model method to return the modified execution
    repository._to_db_model = MagicMock(return_value=modified_execution)

    # Mock session.get to return an existing record
    existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
    session_obj.get.return_value = existing_model

    # Call save method
    repository.save(execution)

    # Assert _to_db_model was called with the execution
    repository._to_db_model.assert_called_once_with(execution)

    # Assert session.get was called to check for an existing record
    session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)

    # Assert session.add was NOT called since we're updating an existing record
    session_obj.add.assert_not_called()

    # Assert session.commit was called
    session_obj.commit.assert_called_once()


def test_get_by_workflow_execution(repository, session, mocker: MockerFixture):
    """Test get_by_workflow_execution method."""
    session_obj, _ = session
    # Set up mocks
    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
    mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
    mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")

    mock_WorkflowNodeExecutionModel = mocker.patch(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
    )
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt
    mock_stmt.order_by.return_value = mock_stmt
    mock_asc.return_value = mock_stmt
    mock_desc.return_value = mock_stmt
    mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt

    # Create a properly configured mock execution
    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
    configure_mock_execution(mock_execution)
    session_obj.scalars.return_value.all.return_value = [mock_execution]

    # Create a mock domain model to be returned by _to_domain_model
    mock_domain_model = mocker.MagicMock()
    # Mock the _to_domain_model method to return our mock domain model
    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

    # Call method
    order_config = OrderConfig(order_by=["index"], order_direction="desc")
    result = repository.get_by_workflow_execution(
        workflow_execution_id="test-workflow-run-id",
        order_config=order_config,
    )

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalars.assert_called_once_with(mock_stmt)
    mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
    # Assert _to_domain_model was called with the mock execution
    repository._to_domain_model.assert_called_once_with(mock_execution)
    # Assert the result contains our mock domain model
    assert len(result) == 1
    assert result[0] is mock_domain_model
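

# Minimal sketch, assuming the repository maps OrderConfig onto SQLAlchemy
# asc()/desc() calls as the patched symbols above suggest; the helper name is
# hypothetical.
def _sketch_order_clauses(order_config: OrderConfig):
    from sqlalchemy import asc, desc

    direction = desc if order_config.order_direction == "desc" else asc
    return [direction(getattr(WorkflowNodeExecutionModel, name)) for name in order_config.order_by]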


def test_to_db_model(repository):
    """Test _to_db_model method."""
    # Create a domain model
    domain_model = WorkflowNodeExecution(
        id="test-id",
        workflow_id="test-workflow-id",
        node_execution_id="test-node-execution-id",
        workflow_execution_id="test-workflow-run-id",
        index=1,
        predecessor_node_id="test-predecessor-id",
        node_id="test-node-id",
        node_type=BuiltinNodeTypes.START,
        title="Test Node",
        inputs={"input_key": "input_value"},
        process_data={"process_key": "process_value"},
        outputs={"output_key": "output_value"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        error=None,
        elapsed_time=1.5,
        metadata={
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
        },
        created_at=datetime.now(),
        finished_at=None,
    )

    # Convert to DB model
    db_model = repository._to_db_model(domain_model)

    # Assert the DB model has correct values
    assert isinstance(db_model, WorkflowNodeExecutionModel)
    assert db_model.id == domain_model.id
    assert db_model.tenant_id == repository._tenant_id
    assert db_model.app_id == repository._app_id
    assert db_model.workflow_id == domain_model.workflow_id
    assert db_model.triggered_from == repository._triggered_from
    assert db_model.workflow_run_id == domain_model.workflow_execution_id
    assert db_model.index == domain_model.index
    assert db_model.predecessor_node_id == domain_model.predecessor_node_id
    assert db_model.node_execution_id == domain_model.node_execution_id
    assert db_model.node_id == domain_model.node_id
    assert db_model.node_type == domain_model.node_type
    assert db_model.title == domain_model.title

    assert db_model.inputs_dict == domain_model.inputs
    assert db_model.process_data_dict == domain_model.process_data
    assert db_model.outputs_dict == domain_model.outputs
    assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)

    assert db_model.status == domain_model.status
    assert db_model.error == domain_model.error
    assert db_model.elapsed_time == domain_model.elapsed_time
    assert db_model.created_at == domain_model.created_at
    assert db_model.created_by_role == repository._creator_user_role
    assert db_model.created_by == repository._creator_user_id
    assert db_model.finished_at == domain_model.finished_at


def test_to_domain_model(repository):
    """Test _to_domain_model method."""
    # Create input dictionaries
    inputs_dict = {"input_key": "input_value"}
    process_data_dict = {"process_key": "process_value"}
    outputs_dict = {"output_key": "output_value"}
    metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}

    # Create a DB model instance
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "test-id"
    db_model.tenant_id = "test-tenant-id"
    db_model.app_id = "test-app-id"
    db_model.workflow_id = "test-workflow-id"
    db_model.triggered_from = "workflow-run"
    db_model.workflow_run_id = "test-workflow-run-id"
    db_model.index = 1
    db_model.predecessor_node_id = "test-predecessor-id"
    db_model.node_execution_id = "test-node-execution-id"
    db_model.node_id = "test-node-id"
    db_model.node_type = BuiltinNodeTypes.START
    db_model.title = "Test Node"
    db_model.inputs = json.dumps(inputs_dict)
    db_model.process_data = json.dumps(process_data_dict)
    db_model.outputs = json.dumps(outputs_dict)
    db_model.status = WorkflowNodeExecutionStatus.RUNNING
    db_model.error = None
    db_model.elapsed_time = 1.5
    db_model.execution_metadata = json.dumps(metadata_dict)
    db_model.created_at = datetime.now()
    db_model.created_by_role = "account"
    db_model.created_by = "test-user-id"
    db_model.finished_at = None

    # Convert to domain model
    domain_model = repository._to_domain_model(db_model)

    # Assert the domain model has correct values
    assert isinstance(domain_model, WorkflowNodeExecution)
    assert domain_model.id == db_model.id
    assert domain_model.workflow_id == db_model.workflow_id
    assert domain_model.workflow_execution_id == db_model.workflow_run_id
    assert domain_model.index == db_model.index
    assert domain_model.predecessor_node_id == db_model.predecessor_node_id
    assert domain_model.node_execution_id == db_model.node_execution_id
    assert domain_model.node_id == db_model.node_id
    assert domain_model.node_type == db_model.node_type
    assert domain_model.title == db_model.title
    assert domain_model.inputs == inputs_dict
    assert domain_model.process_data == process_data_dict
    assert domain_model.outputs == outputs_dict
    assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status)
    assert domain_model.error == db_model.error
    assert domain_model.elapsed_time == db_model.elapsed_time
    assert domain_model.metadata == metadata_dict
    assert domain_model.created_at == db_model.created_at
    assert domain_model.finished_at == db_model.finished_at
@@ -1,106 +0,0 @@
"""
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
"""

from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, Mock

from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_node_execution_repository import (
    SQLAlchemyWorkflowNodeExecutionRepository,
)
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom


class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
    """Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""

    def create_mock_account(self) -> Account:
        """Create a mock Account for testing."""
        account = Mock(spec=Account)
        account.id = "test-user-id"
        account.tenant_id = "test-tenant-id"
        return account

    def create_mock_session_factory(self) -> sessionmaker:
        """Create a mock session factory for testing."""
        mock_session = MagicMock()
        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_session_factory.return_value.__enter__.return_value = mock_session
        mock_session_factory.return_value.__exit__.return_value = None
        return mock_session_factory

    def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository:
        """Create a repository instance for testing."""
        mock_account = self.create_mock_account()
        mock_session_factory = self.create_mock_session_factory()

        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        if mock_file_service:
            repository._file_service = mock_file_service

        return repository

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
        execution_id: str = "test-execution-id",
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution instance for testing."""
        return WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=BuiltinNodeTypes.LLM,
            title="Test Node",
            process_data=process_data,
            created_at=datetime.now(),
        )

    def test_to_domain_model_without_offload_data(self):
        """Test _to_domain_model without offload data."""
        repository = self.create_repository()

        # Create a mock database model without offload data
        db_model = Mock(spec=WorkflowNodeExecutionModel)
        db_model.id = "test-execution-id"
        db_model.node_execution_id = "test-node-execution-id"
        db_model.workflow_id = "test-workflow-id"
        db_model.workflow_run_id = None
        db_model.index = 1
        db_model.predecessor_node_id = None
        db_model.node_id = "test-node-id"
        db_model.node_type = "llm"
        db_model.title = "Test Node"
        db_model.status = "succeeded"
        db_model.error = None
        db_model.elapsed_time = 1.5
        db_model.created_at = datetime.now()
        db_model.finished_at = None

        process_data = {"normal": "data"}
        db_model.process_data_dict = process_data
        db_model.inputs_dict = None
        db_model.outputs_dict = None
        db_model.execution_metadata_dict = {}
        db_model.offload_data = None

        domain_model = repository._to_domain_model(db_model)

        # The domain model should carry the data from the database
        assert domain_model.process_data == process_data

        # It should not be truncated
        assert domain_model.process_data_truncated is False
        assert domain_model.get_truncated_process_data() is None
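

# Hedged sketch of the contract exercised above: with no offload record, the
# inline process_data is the full payload and nothing is reported as truncated.
# The truncated branch below is an assumption, not the repository's actual code.
def _sketch_resolve_process_data(db_model: WorkflowNodeExecutionModel):
    if db_model.offload_data is None:
        return db_model.process_data_dict, False  # (payload, truncated?)
    # Assumed: an offload record means only a truncated preview is stored inline.
    return db_model.process_data_dict, True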
File diff suppressed because it is too large
@@ -1,825 +0,0 @@
"""
Comprehensive unit tests for DatasetPermissionService and DatasetService permission methods.

This module contains extensive unit tests for dataset permission management,
including partial member list operations, permission validation, and permission
enum handling.

The DatasetPermissionService provides methods for:
- Retrieving partial member permissions (get_dataset_partial_member_list)
- Updating partial member lists (update_partial_member_list)
- Validating permissions before operations (check_permission)
- Clearing partial member lists (clear_partial_member_list)

The DatasetService provides permission checking methods:
- check_dataset_permission - validates user access to a dataset
- check_dataset_operator_permission - validates operator permissions

These operations are critical for dataset access control and security, ensuring
that users can only access datasets they have permission to view or modify.

This test suite ensures:
- Correct retrieval of partial member lists
- Proper update of partial member permissions
- Accurate permission validation logic
- Proper handling of permission enums (only_me, all_team_members, partial_members)
- Security boundaries are maintained
- Error conditions are handled correctly

================================================================================
ARCHITECTURE OVERVIEW
================================================================================

The Dataset permission system is a multi-layered access control mechanism
that provides fine-grained control over who can access and modify datasets.

1. Permission Levels:
   - only_me: Only the dataset creator can access
   - all_team_members: All members of the tenant can access
   - partial_members: Only specific users listed in DatasetPermission can access

2. Permission Storage:
   - Dataset.permission: Stores the permission level enum
   - DatasetPermission: Stores individual user permissions for partial_members
   - Each DatasetPermission record links a dataset to a user account

3. Permission Validation:
   - Tenant-level checks: Users must be in the same tenant
   - Role-based checks: The OWNER role bypasses some restrictions
   - Explicit permission checks: For partial_members, explicit DatasetPermission
     records are required

4. Permission Operations:
   - Partial member list management: Add/remove users from partial access
   - Permission validation: Check before allowing operations
   - Permission clearing: Remove all partial members when changing the permission level

A hedged sketch of this access gate follows the imports below.

================================================================================
TESTING STRATEGY
================================================================================

This test suite follows a comprehensive testing strategy that covers:

1. Partial Member List Operations:
   - Retrieving member lists
   - Adding new members
   - Updating existing members
   - Removing members
   - Empty list handling

2. Permission Validation:
   - Dataset editor permissions
   - Dataset operator restrictions
   - Permission enum validation
   - Partial member list validation
   - Tenant isolation

3. Permission Enum Handling:
   - only_me permission behavior
   - all_team_members permission behavior
   - partial_members permission behavior
   - Permission transitions
   - Edge cases for each enum value

4. Security and Access Control:
   - Tenant boundary enforcement
   - Role-based access control
   - Creator privilege validation
   - Explicit permission requirements

5. Error Handling:
   - Invalid permission changes
   - Missing required data
   - Database transaction failures
   - Permission denial scenarios

================================================================================
"""

from unittest.mock import Mock, create_autospec, patch

import pytest

from models import Account, TenantAccountRole
from models.dataset import (
    Dataset,
    DatasetPermission,
    DatasetPermissionEnum,
)
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
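

# A minimal sketch (assumptions flagged inline) of the access gate this suite
# verifies; the real DatasetService.check_dataset_permission may differ.
def _sketch_access_gate(dataset, user) -> None:
    if dataset.tenant_id != user.current_tenant_id:
        raise NoPermissionError("You do not have permission to access this dataset.")
    if user.current_role == TenantAccountRole.OWNER:
        return  # OWNER bypasses per-dataset checks
    if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
        raise NoPermissionError("You do not have permission to access this dataset.")
    if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM and dataset.created_by != user.id:
        # Assumed: a matching DatasetPermission row is required here (query omitted).
        pass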


# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
#    ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of models or services changes, we only
#    need to update the factory methods rather than every individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
#    reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
#    method calls instead of complex object construction logic.
#
# A short usage sketch follows the factory class below.
# ============================================================================


class DatasetPermissionTestDataFactory:
    """
    Factory class for creating test data and mock objects for dataset permission tests.

    This factory provides static methods to create mock objects for:
    - Dataset instances with various permission configurations
    - User/Account instances with different roles and permissions
    - DatasetPermission instances
    - Permission enum values
    - Database query results

    The factory methods help maintain consistency across tests and reduce
    code duplication when setting up test scenarios.
    """

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
        created_by: str = "user-123",
        name: str = "Test Dataset",
        **kwargs,
    ) -> Mock:
        """
        Create a mock Dataset with specified attributes.

        Args:
            dataset_id: Unique identifier for the dataset
            tenant_id: Tenant identifier
            permission: Permission level enum
            created_by: ID of the user who created the dataset
            name: Dataset name
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a Dataset instance
        """
        dataset = Mock(spec=Dataset)
        dataset.id = dataset_id
        dataset.tenant_id = tenant_id
        dataset.permission = permission
        dataset.created_by = created_by
        dataset.name = name
        for key, value in kwargs.items():
            setattr(dataset, key, value)
        return dataset

    @staticmethod
    def create_user_mock(
        user_id: str = "user-123",
        tenant_id: str = "tenant-123",
        role: TenantAccountRole = TenantAccountRole.NORMAL,
        is_dataset_editor: bool = True,
        is_dataset_operator: bool = False,
        **kwargs,
    ) -> Mock:
        """
        Create a mock user (Account) with specified attributes.

        Args:
            user_id: Unique identifier for the user
            tenant_id: Tenant identifier
            role: User role (OWNER, ADMIN, NORMAL, DATASET_OPERATOR, etc.)
            is_dataset_editor: Whether the user has dataset editor permissions
            is_dataset_operator: Whether the user is a dataset operator
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as an Account instance
        """
        user = create_autospec(Account, instance=True)
        user.id = user_id
        user.current_tenant_id = tenant_id
        user.current_role = role
        user.is_dataset_editor = is_dataset_editor
        user.is_dataset_operator = is_dataset_operator
        for key, value in kwargs.items():
            setattr(user, key, value)
        return user

    @staticmethod
    def create_dataset_permission_mock(
        permission_id: str = "permission-123",
        dataset_id: str = "dataset-123",
        account_id: str = "user-456",
        tenant_id: str = "tenant-123",
        has_permission: bool = True,
        **kwargs,
    ) -> Mock:
        """
        Create a mock DatasetPermission instance.

        Args:
            permission_id: Unique identifier for the permission
            dataset_id: Dataset ID
            account_id: User account ID
            tenant_id: Tenant identifier
            has_permission: Whether permission is granted
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a DatasetPermission instance
        """
        permission = Mock(spec=DatasetPermission)
        permission.id = permission_id
        permission.dataset_id = dataset_id
        permission.account_id = account_id
        permission.tenant_id = tenant_id
        permission.has_permission = has_permission
        for key, value in kwargs.items():
            setattr(permission, key, value)
        return permission

    @staticmethod
    def create_user_list_mock(user_ids: list[str]) -> list[dict[str, str]]:
        """
        Create a list of user dictionaries for partial member list operations.

        Args:
            user_ids: List of user IDs to include

        Returns:
            List of user dictionaries with "user_id" keys
        """
        return [{"user_id": user_id} for user_id in user_ids]


# ============================================================================
# Tests for check_permission
# ============================================================================


class TestDatasetPermissionServiceCheckPermission:
    """
    Comprehensive unit tests for the DatasetPermissionService.check_permission method.

    This test class covers the permission validation logic that ensures
    users have the appropriate permissions to modify dataset permissions.

    The check_permission method:
    1. Validates the user is a dataset editor
    2. Checks if a dataset operator is trying to change permissions
    3. Validates the partial member list when setting to partial_members
    4. Ensures dataset operators cannot change permission levels
    5. Ensures dataset operators cannot modify partial member lists

    Test scenarios include:
    - Valid permission changes by dataset editors
    - Dataset operator restrictions
    - Partial member list validation
    - Missing dataset editor permissions
    - Invalid permission changes
    """

    @pytest.fixture
    def mock_get_partial_member_list(self):
        """
        Mock the get_dataset_partial_member_list method.

        Provides a mocked version of the get_dataset_partial_member_list
        method for testing permission validation logic.
        """
        with patch.object(DatasetPermissionService, "get_dataset_partial_member_list") as mock_get_list:
            yield mock_get_list

    def test_check_permission_dataset_editor_success(self, mock_get_partial_member_list):
        """
        Test a successful permission check for a dataset editor.

        Verifies that when a dataset editor (not an operator) tries to
        change permissions, the check passes.

        This test ensures:
        - Dataset editors can change permissions
        - No errors are raised for valid changes
        - Partial member list validation is skipped for non-operators
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=False)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
        requested_permission = DatasetPermissionEnum.ALL_TEAM
        requested_partial_member_list = None

        # Act (should not raise)
        DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list)

        # Assert
        # Verify get_partial_member_list was not called (not needed for non-operators)
        mock_get_partial_member_list.assert_not_called()

    def test_check_permission_not_dataset_editor_error(self):
        """
        Test the error raised when the user is not a dataset editor.

        Verifies that when a user without dataset editor permissions
        tries to change permissions, a NoPermissionError is raised.

        This test ensures:
        - Non-editors cannot change permissions
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=False)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock()
        requested_permission = DatasetPermissionEnum.ALL_TEAM
        requested_partial_member_list = None

        # Act & Assert
        with pytest.raises(NoPermissionError, match="User does not have permission to edit this dataset"):
            DatasetPermissionService.check_permission(
                user, dataset, requested_permission, requested_partial_member_list
            )

    def test_check_permission_operator_cannot_change_permission_error(self):
        """
        Test the error raised when a dataset operator tries to change the permission level.

        Verifies that when a dataset operator tries to change the permission
        level, a NoPermissionError is raised.

        This test ensures:
        - Dataset operators cannot change permission levels
        - The error message is clear
        - The current permission is preserved
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
        requested_permission = DatasetPermissionEnum.ALL_TEAM  # Trying to change
        requested_partial_member_list = None

        # Act & Assert
        with pytest.raises(NoPermissionError, match="Dataset operators cannot change the dataset permissions"):
            DatasetPermissionService.check_permission(
                user, dataset, requested_permission, requested_partial_member_list
            )

    def test_check_permission_operator_partial_members_missing_list_error(self, mock_get_partial_member_list):
        """
        Test the error raised when an operator sets partial_members without providing a list.

        Verifies that when a dataset operator tries to set the permission to
        partial_members without providing a member list, a ValueError is raised.

        This test ensures:
        - A partial member list is required for the partial_members permission
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
        requested_permission = "partial_members"
        requested_partial_member_list = None  # Missing list

        # Act & Assert
        with pytest.raises(ValueError, match="Partial member list is required when setting to partial members"):
            DatasetPermissionService.check_permission(
                user, dataset, requested_permission, requested_partial_member_list
            )

    def test_check_permission_operator_cannot_modify_partial_list_error(self, mock_get_partial_member_list):
        """
        Test the error raised when an operator tries to modify the partial member list.

        Verifies that when a dataset operator tries to change the partial
        member list, a ValueError is raised.

        This test ensures:
        - Dataset operators cannot modify partial member lists
        - The error message is clear
        - The current member list is preserved
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
        requested_permission = "partial_members"

        # Current member list
        current_member_list = ["user-456", "user-789"]
        mock_get_partial_member_list.return_value = current_member_list

        # Requested member list (different from current)
        requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock(
            ["user-456", "user-999"]  # Different list
        )

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset operators cannot change the dataset permissions"):
            DatasetPermissionService.check_permission(
                user, dataset, requested_permission, requested_partial_member_list
            )

    def test_check_permission_operator_can_keep_same_partial_list(self, mock_get_partial_member_list):
        """
        Test that an operator can keep the same partial member list.

        Verifies that when a dataset operator keeps the same partial member
        list, the check passes.

        This test ensures:
        - Operators can keep existing partial member lists
        - No errors are raised for unchanged lists
        - Permission validation works correctly
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
        requested_permission = "partial_members"

        # Current member list
        current_member_list = ["user-456", "user-789"]
        mock_get_partial_member_list.return_value = current_member_list

        # Requested member list (same as current)
        requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock(
            ["user-456", "user-789"]  # Same list
        )

        # Act (should not raise)
        DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list)

        # Assert
        # Verify get_partial_member_list was called to compare the lists
        mock_get_partial_member_list.assert_called_once_with(dataset.id)
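

# Hedged sketch of the operator rule the two tests above pin down: an operator
# may resubmit the existing member list unchanged but may not alter it. Comparing
# the lists as sets is an assumption about the real implementation.
def _sketch_operator_list_rule(current_ids: list[str], requested: list[dict[str, str]]) -> None:
    requested_ids = [member["user_id"] for member in requested]
    if set(requested_ids) != set(current_ids):
        raise ValueError("Dataset operators cannot change the dataset permissions.")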


# ============================================================================
# Tests for DatasetService.check_dataset_permission
# ============================================================================


class TestDatasetServiceCheckDatasetPermission:
    """
    Comprehensive unit tests for the DatasetService.check_dataset_permission method.

    This test class covers the dataset permission checking logic that validates
    whether a user has access to a dataset based on permission enums.

    The check_dataset_permission method:
    1. Validates the tenant match
    2. Checks the OWNER role (bypasses some restrictions)
    3. Validates the only_me permission (creator only)
    4. Validates the partial_members permission (explicit permission required)
    5. Validates the all_team_members permission (all tenant members)

    Test scenarios include:
    - Tenant boundary enforcement
    - OWNER role bypass
    - only_me permission validation
    - partial_members permission validation
    - all_team_members permission validation
    - Permission denial scenarios
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session for testing.

        Provides a mocked database session that can be used to verify
        database queries for permission checks.
        """
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    def test_check_dataset_permission_owner_bypass(self, mock_db_session):
        """
        Test that the OWNER role bypasses permission checks.

        Verifies that when a user has the OWNER role, they can access any
        dataset in their tenant regardless of permission level.

        This test ensures:
        - The OWNER role bypasses permission restrictions
        - No database queries are needed for an OWNER
        - Access is granted automatically
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123")
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="other-user-123",  # Not the current user
        )

        # Act (should not raise)
        DatasetService.check_dataset_permission(dataset, user)

        # Assert
        # Verify no permission queries were made (OWNER bypasses)
        mock_db_session.query.assert_not_called()

    def test_check_dataset_permission_tenant_mismatch_error(self):
        """
        Test the error raised when the user and dataset are in different tenants.

        Verifies that when a user tries to access a dataset from a different
        tenant, a NoPermissionError is raised.

        This test ensures:
        - The tenant boundary is enforced
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(tenant_id="tenant-123")
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(tenant_id="tenant-456")  # Different tenant

        # Act & Assert
        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_only_me_creator_success(self):
        """
        Test that the creator can access an only_me dataset.

        Verifies that when a user is the creator of an only_me dataset,
        they can access it successfully.

        This test ensures:
        - Creators can access their own only_me datasets
        - No explicit permission record is needed
        - Access is granted correctly
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="user-123",  # User is the creator
        )

        # Act (should not raise)
        DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_only_me_non_creator_error(self):
        """
        Test the error raised when a non-creator tries to access an only_me dataset.

        Verifies that when a user who is not the creator tries to access
        an only_me dataset, a NoPermissionError is raised.

        This test ensures:
        - Non-creators cannot access only_me datasets
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="other-user-456",  # Different creator
        )

        # Act & Assert
        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session):
        """
        Test that the creator can access a partial_members dataset without explicit permission.

        Verifies that when a user is the creator of a partial_members dataset,
        they can access it even without an explicit DatasetPermission record.

        This test ensures:
        - Creators can access their own datasets
        - No explicit permission record is needed for creators
        - Access is granted correctly
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="user-123",  # User is the creator
        )

        # Act (should not raise)
        DatasetService.check_dataset_permission(dataset, user)

        # Assert
        # Verify the permission query was not executed (creator bypasses)
        mock_db_session.query.assert_not_called()

    def test_check_dataset_permission_all_team_members_success(self):
        """
        Test that any tenant member can access an all_team_members dataset.

        Verifies that when a dataset has the all_team_members permission, any
        user in the same tenant can access it.

        This test ensures:
        - All team members can access
        - No explicit permission record is needed
        - Access is granted correctly
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ALL_TEAM,
            created_by="other-user-456",  # Not the creator
        )

        # Act (should not raise)
        DatasetService.check_dataset_permission(dataset, user)


# ============================================================================
# Tests for DatasetService.check_dataset_operator_permission
# ============================================================================


class TestDatasetServiceCheckDatasetOperatorPermission:
    """
    Comprehensive unit tests for the DatasetService.check_dataset_operator_permission method.

    This test class covers the dataset operator permission checking logic,
    which validates whether a dataset operator has access to a dataset.

    The check_dataset_operator_permission method:
    1. Validates the dataset exists
    2. Validates the user exists
    3. Checks the OWNER role (bypasses restrictions)
    4. Validates the only_me permission (creator only)
    5. Validates the partial_members permission (explicit permission required)

    Test scenarios include:
    - Dataset not found error
    - User not found error
    - OWNER role bypass
    - only_me permission validation
    - partial_members permission validation
    - Permission denial scenarios
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session for testing.

        Provides a mocked database session that can be used to verify
        database queries for permission checks.
        """
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    def test_check_dataset_operator_permission_dataset_not_found_error(self):
        """
        Test the error raised when the dataset is None.

        Verifies that when the dataset is None, a ValueError is raised.

        This test ensures:
        - Dataset existence is validated
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock()
        dataset = None

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset not found"):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_check_dataset_operator_permission_user_not_found_error(self):
        """
        Test the error raised when the user is None.

        Verifies that when the user is None, a ValueError is raised.

        This test ensures:
        - User existence is validated
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = None
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="User not found"):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_check_dataset_operator_permission_owner_bypass(self):
        """
        Test that the OWNER role bypasses permission checks.

        Verifies that when a user has the OWNER role, they can access any
        dataset in their tenant regardless of permission level.

        This test ensures:
        - The OWNER role bypasses permission restrictions
        - No database queries are needed for an OWNER
        - Access is granted automatically
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123")
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="other-user-123",  # Not the current user
        )

        # Act (should not raise)
        DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_check_dataset_operator_permission_only_me_creator_success(self):
        """
        Test that the creator can access an only_me dataset.

        Verifies that when a user is the creator of an only_me dataset,
        they can access it successfully.

        This test ensures:
        - Creators can access their own only_me datasets
        - No explicit permission record is needed
        - Access is granted correctly
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="user-123",  # User is the creator
        )

        # Act (should not raise)
        DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_check_dataset_operator_permission_only_me_non_creator_error(self):
        """
        Test the error raised when a non-creator tries to access an only_me dataset.

        Verifies that when a user who is not the creator tries to access
        an only_me dataset, a NoPermissionError is raised.

        This test ensures:
        - Non-creators cannot access only_me datasets
        - The error message is clear
        - The error type is correct
        """
        # Arrange
        user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL)
        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
            tenant_id="tenant-123",
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="other-user-456",  # Different creator
        )

        # Act & Assert
        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)


# ============================================================================
# Additional Documentation and Notes
# ============================================================================
#
# This test suite covers the core permission management operations for datasets.
# Additional test scenarios that could be added:
#
# 1. Permission Enum Transitions:
#    - Testing transitions between permission levels
#    - Testing validation during transitions
#    - Testing partial member list updates during transitions
#
# 2. Bulk Operations:
#    - Testing bulk permission updates
#    - Testing bulk partial member list updates
#    - Testing performance with large member lists
#
# 3. Edge Cases:
#    - Testing with very large partial member lists
#    - Testing with special characters in user IDs
#    - Testing with deleted users
#    - Testing with inactive permissions
#
# 4. Integration Scenarios:
#    - Testing permission changes followed by access attempts
#    - Testing concurrent permission updates
#    - Testing permission inheritance
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================
@@ -1,818 +0,0 @@
"""
Comprehensive unit tests for DatasetService update and delete operations.

This module contains extensive unit tests for the DatasetService class,
specifically focusing on update and delete operations for datasets.

The DatasetService provides methods for:
- Updating dataset configuration and settings (update_dataset)
- Deleting datasets with proper cleanup (delete_dataset)
- Updating RAG pipeline dataset settings (update_rag_pipeline_dataset_settings)
- Checking if a dataset is in use (dataset_use_check)
- Updating dataset API access status (update_dataset_api_status)

These operations are critical for dataset lifecycle management and require
careful handling of permissions, dependencies, and data integrity.

This test suite ensures:
- Correct update of dataset properties
- Proper permission validation before updates/deletes
- Cascade deletion handling
- Event signaling for cleanup operations
- RAG pipeline dataset configuration updates
- API status management
- Use check validation

================================================================================
ARCHITECTURE OVERVIEW
================================================================================

The DatasetService update and delete operations are part of the dataset
lifecycle management system. These operations interact with multiple
components:

1. Permission System: All update/delete operations require proper
   permission validation to ensure users can only modify datasets they
   have access to.

2. Event System: Dataset deletion triggers the dataset_was_deleted event,
   which notifies other components to clean up related data (documents,
   segments, vector indices, etc.).

3. Dependency Checking: Before deletion, the system checks if the dataset
   is in use by any applications (via AppDatasetJoin).

4. RAG Pipeline Integration: RAG pipeline datasets have special update
   logic that handles chunk structure, indexing techniques, and embedding
   model configuration.

5. API Status Management: Datasets can have their API access enabled or
   disabled, which affects whether they can be accessed via the API.

A hedged sketch of the delete flow follows the imports below.

================================================================================
TESTING STRATEGY
================================================================================

This test suite follows a comprehensive testing strategy that covers:

1. Update Operations:
   - Internal dataset updates
   - External dataset updates
   - RAG pipeline dataset updates
   - Permission validation
   - Duplicate name checking
   - Configuration validation

2. Delete Operations:
   - Successful deletion
   - Permission validation
   - Event signaling
   - Database cleanup
   - Not-found handling

3. Use Check Operations:
   - Dataset in use detection
   - Dataset not in use detection
   - AppDatasetJoin query validation

4. API Status Operations:
   - Enable API access
   - Disable API access
   - Permission validation
   - Current user validation

5. RAG Pipeline Operations:
   - Unpublished dataset updates
   - Published dataset updates
   - Chunk structure validation
   - Indexing technique changes
   - Embedding model configuration

================================================================================
"""

import datetime
from unittest.mock import Mock, create_autospec, patch

import pytest
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, TenantAccountRole
from models.dataset import (
    AppDatasetJoin,
    Dataset,
    DatasetPermissionEnum,
)
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
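

# Minimal sketch (assumptions flagged) of the delete flow the overview describes:
# permission check, delete, then an event signal that fans out cleanup. The
# `dataset_was_deleted` hook name comes from the docstring above; the session
# handling here is a guess, not DatasetService's actual code.
def _sketch_delete_dataset(dataset: Dataset | None, user: Account, session: Session) -> bool:
    if dataset is None:
        return False
    DatasetService.check_dataset_permission(dataset, user)
    # dataset_was_deleted.send(dataset)  # assumed signal that triggers cleanup
    session.delete(dataset)
    session.commit()
    return True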
|
||||
|
||||
# ============================================================================
|
||||
# Test Data Factory
|
||||
# ============================================================================
|
||||
# The Test Data Factory pattern is used here to centralize the creation of
|
||||
# test objects and mock instances. This approach provides several benefits:
|
||||
#
|
||||
# 1. Consistency: All test objects are created using the same factory methods,
|
||||
# ensuring consistent structure across all tests.
|
||||
#
|
||||
# 2. Maintainability: If the structure of models or services changes, we only
|
||||
# need to update the factory methods rather than every individual test.
|
||||
#
|
||||
# 3. Reusability: Factory methods can be reused across multiple test classes,
|
||||
# reducing code duplication.
|
||||
#
|
||||
# 4. Readability: Tests become more readable when they use descriptive factory
|
||||
# method calls instead of complex object construction logic.
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class DatasetUpdateDeleteTestDataFactory:
    """
    Factory class for creating test data and mock objects for dataset update/delete tests.

    This factory provides static methods to create mock objects for:
    - Dataset instances with various configurations
    - User/Account instances with different roles
    - Knowledge configuration objects
    - Database session mocks
    - Event signal mocks

    The factory methods help maintain consistency across tests and reduce
    code duplication when setting up test scenarios.
    """

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        provider: str = "vendor",
        name: str = "Test Dataset",
        description: str = "Test description",
        tenant_id: str = "tenant-123",
        indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
        embedding_model_provider: str | None = "openai",
        embedding_model: str | None = "text-embedding-ada-002",
        collection_binding_id: str | None = "binding-123",
        enable_api: bool = True,
        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
        created_by: str = "user-123",
        chunk_structure: str | None = None,
        runtime_mode: str = "general",
        **kwargs,
    ) -> Mock:
        """
        Create a mock Dataset with specified attributes.

        Args:
            dataset_id: Unique identifier for the dataset
            provider: Dataset provider (vendor, external)
            name: Dataset name
            description: Dataset description
            tenant_id: Tenant identifier
            indexing_technique: Indexing technique (high_quality, economy)
            embedding_model_provider: Embedding model provider
            embedding_model: Embedding model name
            collection_binding_id: Collection binding ID
            enable_api: Whether API access is enabled
            permission: Dataset permission level
            created_by: ID of user who created the dataset
            chunk_structure: Chunk structure for RAG pipeline datasets
            runtime_mode: Runtime mode (general, rag_pipeline)
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a Dataset instance
        """
        dataset = Mock(spec=Dataset)
        dataset.id = dataset_id
        dataset.provider = provider
        dataset.name = name
        dataset.description = description
        dataset.tenant_id = tenant_id
        dataset.indexing_technique = indexing_technique
        dataset.embedding_model_provider = embedding_model_provider
        dataset.embedding_model = embedding_model
        dataset.collection_binding_id = collection_binding_id
        dataset.enable_api = enable_api
        dataset.permission = permission
        dataset.created_by = created_by
        dataset.chunk_structure = chunk_structure
        dataset.runtime_mode = runtime_mode
        dataset.retrieval_model = {}
        dataset.keyword_number = 10
        for key, value in kwargs.items():
            setattr(dataset, key, value)
        return dataset

    @staticmethod
    def create_user_mock(
        user_id: str = "user-123",
        tenant_id: str = "tenant-123",
        role: TenantAccountRole = TenantAccountRole.NORMAL,
        is_dataset_editor: bool = True,
        **kwargs,
    ) -> Mock:
        """
        Create a mock user (Account) with specified attributes.

        Args:
            user_id: Unique identifier for the user
            tenant_id: Tenant identifier
            role: User role (OWNER, ADMIN, NORMAL, etc.)
            is_dataset_editor: Whether user has dataset editor permissions
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as an Account instance
        """
        user = create_autospec(Account, instance=True)
        user.id = user_id
        user.current_tenant_id = tenant_id
        user.current_role = role
        user.is_dataset_editor = is_dataset_editor
        for key, value in kwargs.items():
            setattr(user, key, value)
        return user

    @staticmethod
    def create_knowledge_configuration_mock(
        chunk_structure: str = "tree",
        indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
        embedding_model_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
        keyword_number: int = 10,
        retrieval_model: dict | None = None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock KnowledgeConfiguration entity.

        Args:
            chunk_structure: Chunk structure type
            indexing_technique: Indexing technique
            embedding_model_provider: Embedding model provider
            embedding_model: Embedding model name
            keyword_number: Keyword number for economy indexing
            retrieval_model: Retrieval model configuration
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a KnowledgeConfiguration instance
        """
        config = Mock()
        config.chunk_structure = chunk_structure
        config.indexing_technique = indexing_technique
        config.embedding_model_provider = embedding_model_provider
        config.embedding_model = embedding_model
        config.keyword_number = keyword_number
        config.retrieval_model = Mock()
        config.retrieval_model.model_dump.return_value = retrieval_model or {
            "search_method": "semantic_search",
            "top_k": 2,
        }
        for key, value in kwargs.items():
            setattr(config, key, value)
        return config

    @staticmethod
    def create_app_dataset_join_mock(
        app_id: str = "app-123",
        dataset_id: str = "dataset-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock AppDatasetJoin instance.

        Args:
            app_id: Application ID
            dataset_id: Dataset ID
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as an AppDatasetJoin instance
        """
        join = Mock(spec=AppDatasetJoin)
        join.app_id = app_id
        join.dataset_id = dataset_id
        for key, value in kwargs.items():
            setattr(join, key, value)
        return join


# ============================================================================
# Tests for update_dataset
# ============================================================================


class TestDatasetServiceUpdateDataset:
    """
    Comprehensive unit tests for DatasetService.update_dataset method.

    This test class covers the dataset update functionality, including
    internal and external dataset updates, permission validation, and
    name duplicate checking.

    The update_dataset method:
    1. Retrieves the dataset by ID
    2. Validates dataset exists
    3. Checks for duplicate names
    4. Validates user permissions
    5. Routes to appropriate update handler (internal or external)
    6. Returns the updated dataset

    Test scenarios include:
    - Successful internal dataset updates
    - Successful external dataset updates
    - Permission validation
    - Duplicate name detection
    - Dataset not found errors
    """
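
    # For orientation, a minimal sketch of the flow these tests exercise, inferred
    # from the assertions below (an illustrative assumption, not the real code):
    #
    #     dataset = DatasetService.get_dataset(dataset_id)
    #     if not dataset:
    #         raise ValueError("Dataset not found")
    #     if DatasetService._has_dataset_same_name(...):
    #         raise ValueError("Dataset name already exists")
    #     DatasetService.check_dataset_permission(dataset, user)
    #     if dataset.provider == "external":
    #         return DatasetService._update_external_dataset(...)
    #     return DatasetService._update_internal_dataset(...)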

    @pytest.fixture
    def mock_dataset_service_dependencies(self):
        """
        Mock dataset service dependencies for testing.

        Provides mocked dependencies including:
        - get_dataset method
        - check_dataset_permission method
        - _has_dataset_same_name method
        - Database session
        - Current time utilities
        """
        with (
            patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
            patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
            patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name,
            patch("extensions.ext_database.db.session") as mock_db,
            patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
        ):
            current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
            mock_naive_utc_now.return_value = current_time

            yield {
                "get_dataset": mock_get_dataset,
                "check_permission": mock_check_perm,
                "has_same_name": mock_has_same_name,
                "db_session": mock_db,
                "naive_utc_now": mock_naive_utc_now,
                "current_time": current_time,
            }

    def test_update_dataset_internal_success(self, mock_dataset_service_dependencies):
        """
        Test successful update of an internal dataset.

        Verifies that when all validation passes, an internal dataset
        is updated correctly through the _update_internal_dataset method.

        This test ensures:
        - Dataset is retrieved correctly
        - Permission is checked
        - Name duplicate check is performed
        - Internal update handler is called
        - Updated dataset is returned
        """
        # Arrange
        dataset_id = "dataset-123"
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id=dataset_id, provider="vendor", name="Old Name"
        )
        user = DatasetUpdateDeleteTestDataFactory.create_user_mock()

        update_data = {
            "name": "New Name",
            "description": "New Description",
        }

        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
        mock_dataset_service_dependencies["has_same_name"].return_value = False

        with patch("services.dataset_service.DatasetService._update_internal_dataset") as mock_update_internal:
            mock_update_internal.return_value = dataset

            # Act
            result = DatasetService.update_dataset(dataset_id, update_data, user)

            # Assert
            assert result == dataset

            # Verify dataset was retrieved
            mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)

            # Verify permission was checked
            mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)

            # Verify name duplicate check was performed
            mock_dataset_service_dependencies["has_same_name"].assert_called_once()

            # Verify internal update handler was called
            mock_update_internal.assert_called_once()

    def test_update_dataset_external_success(self, mock_dataset_service_dependencies):
        """
        Test successful update of an external dataset.

        Verifies that when all validation passes, an external dataset
        is updated correctly through the _update_external_dataset method.

        This test ensures:
        - Dataset is retrieved correctly
        - Permission is checked
        - Name duplicate check is performed
        - External update handler is called
        - Updated dataset is returned
        """
        # Arrange
        dataset_id = "dataset-123"
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id=dataset_id, provider="external", name="Old Name"
        )
        user = DatasetUpdateDeleteTestDataFactory.create_user_mock()

        update_data = {
            "name": "New Name",
            "external_knowledge_id": "new-knowledge-id",
        }

        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
        mock_dataset_service_dependencies["has_same_name"].return_value = False

        with patch("services.dataset_service.DatasetService._update_external_dataset") as mock_update_external:
            mock_update_external.return_value = dataset

            # Act
            result = DatasetService.update_dataset(dataset_id, update_data, user)

            # Assert
            assert result == dataset

            # Verify external update handler was called
            mock_update_external.assert_called_once()

    def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
        """
        Test error handling when dataset is not found.

        Verifies that when the dataset ID doesn't exist, a ValueError
        is raised with an appropriate message.

        This test ensures:
        - Dataset not found error is handled correctly
        - No update operations are performed
        - Error message is clear
        """
        # Arrange
        dataset_id = "non-existent-dataset"
        user = DatasetUpdateDeleteTestDataFactory.create_user_mock()

        update_data = {"name": "New Name"}

        mock_dataset_service_dependencies["get_dataset"].return_value = None

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset not found"):
            DatasetService.update_dataset(dataset_id, update_data, user)

        # Verify no update operations were attempted
        mock_dataset_service_dependencies["check_permission"].assert_not_called()
        mock_dataset_service_dependencies["has_same_name"].assert_not_called()

    def test_update_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
        """
        Test error handling when dataset name already exists.

        Verifies that when a dataset with the same name already exists
        in the tenant, a ValueError is raised.

        This test ensures:
        - Duplicate name detection works correctly
        - Error message is clear
        - No update operations are performed
        """
        # Arrange
        dataset_id = "dataset-123"
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
        user = DatasetUpdateDeleteTestDataFactory.create_user_mock()

        update_data = {"name": "Existing Name"}

        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
        mock_dataset_service_dependencies["has_same_name"].return_value = True  # Duplicate exists

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset name already exists"):
            DatasetService.update_dataset(dataset_id, update_data, user)

        # Verify permission check was not called (fails before that)
        mock_dataset_service_dependencies["check_permission"].assert_not_called()

    def test_update_dataset_permission_denied_error(self, mock_dataset_service_dependencies):
        """
        Test error handling when user lacks permission.

        Verifies that when the user doesn't have permission to update
        the dataset, a NoPermissionError is raised.

        This test ensures:
        - Permission validation works correctly
        - Error is raised before any updates
        - Error type is correct
        """
        # Arrange
        dataset_id = "dataset-123"
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
        user = DatasetUpdateDeleteTestDataFactory.create_user_mock()

        update_data = {"name": "New Name"}

        mock_dataset_service_dependencies["get_dataset"].return_value = dataset
        mock_dataset_service_dependencies["has_same_name"].return_value = False
        mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")

        # Act & Assert
        with pytest.raises(NoPermissionError):
            DatasetService.update_dataset(dataset_id, update_data, user)


# ============================================================================
# Tests for update_rag_pipeline_dataset_settings
# ============================================================================


class TestDatasetServiceUpdateRagPipelineDatasetSettings:
    """
    Comprehensive unit tests for DatasetService.update_rag_pipeline_dataset_settings method.

    This test class covers the RAG pipeline dataset settings update functionality,
    including chunk structure, indexing technique, and embedding model configuration.

    The update_rag_pipeline_dataset_settings method:
    1. Validates current_user and tenant
    2. Merges dataset into session
    3. Handles unpublished vs published datasets differently
    4. Updates chunk structure, indexing technique, and retrieval model
    5. Configures embedding model for high_quality indexing
    6. Updates keyword_number for economy indexing
    7. Commits transaction
    8. Triggers index update tasks if needed

    Test scenarios include:
    - Unpublished dataset updates
    - Published dataset updates
    - Chunk structure validation
    - Indexing technique changes
    - Embedding model configuration
    - Error handling
    """
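
    # For orientation, a minimal sketch of the validation order these tests rely
    # on, inferred from the assertions below (an illustrative assumption, not the
    # real code):
    #
    #     if not current_user or not current_user.current_tenant_id:
    #         raise ValueError("Current user or current tenant not found")
    #     dataset = session.merge(dataset)
    #     if has_published and dataset.chunk_structure != knowledge_config.chunk_structure:
    #         raise ValueError("Chunk structure is not allowed to be updated")
    #     if has_published and knowledge_config.indexing_technique == IndexTechniqueType.ECONOMY:
    #         raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy")
    #     ... apply chunk structure, indexing technique, embedding model and retrieval model ...
    #     session.add(dataset)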

    @pytest.fixture
    def mock_session(self):
        """
        Mock database session for testing.

        Provides a mocked SQLAlchemy session for testing session operations.
        """
        return Mock(spec=Session)

    @pytest.fixture
    def mock_dataset_service_dependencies(self):
        """
        Mock dataset service dependencies for testing.

        Provides mocked dependencies including:
        - current_user context
        - ModelManager
        - DatasetCollectionBindingService
        - Database session operations
        - Task scheduling
        """
        with (
            patch(
                "services.dataset_service.current_user", create_autospec(Account, instance=True)
            ) as mock_current_user,
            patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
            ) as mock_get_binding,
            patch("services.dataset_service.deal_dataset_index_update_task") as mock_task,
        ):
            mock_current_user.current_tenant_id = "tenant-123"
            mock_current_user.id = "user-123"

            yield {
                "current_user": mock_current_user,
                "model_manager": mock_model_manager,
                "get_binding": mock_get_binding,
                "task": mock_task,
            }

    def test_update_rag_pipeline_dataset_settings_unpublished_success(
        self, mock_session, mock_dataset_service_dependencies
    ):
        """
        Test successful update of unpublished RAG pipeline dataset.

        Verifies that when a dataset is not published, all settings can
        be updated including chunk structure and indexing technique.

        This test ensures:
        - Current user validation passes
        - Dataset is merged into session
        - Chunk structure is updated
        - Indexing technique is updated
        - Embedding model is configured for high_quality
        - Retrieval model is updated
        - Dataset is added to session
        """
        # Arrange
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id="dataset-123",
            runtime_mode="rag_pipeline",
            chunk_structure="tree",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        )

        knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
            chunk_structure="list",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
            embedding_model_provider="openai",
            embedding_model="text-embedding-ada-002",
        )

        # Mock embedding model
        mock_embedding_model = Mock()
        mock_embedding_model.model_name = "text-embedding-ada-002"
        mock_embedding_model.provider = "openai"
        mock_embedding_model.credentials = {}

        mock_model_schema = Mock()
        mock_model_schema.features = []

        mock_text_embedding_model = Mock()
        mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
        mock_embedding_model.model_type_instance = mock_text_embedding_model

        mock_model_instance = Mock()
        mock_model_instance.get_model_instance.return_value = mock_embedding_model
        mock_dataset_service_dependencies["model_manager"].return_value = mock_model_instance

        # Mock collection binding
        mock_binding = Mock()
        mock_binding.id = "binding-123"
        mock_dataset_service_dependencies["get_binding"].return_value = mock_binding

        mock_session.merge.return_value = dataset

        # Act
        DatasetService.update_rag_pipeline_dataset_settings(
            mock_session, dataset, knowledge_config, has_published=False
        )

        # Assert
        assert dataset.chunk_structure == "list"
        assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY
        assert dataset.embedding_model == "text-embedding-ada-002"
        assert dataset.embedding_model_provider == "openai"
        assert dataset.collection_binding_id == "binding-123"

        # Verify dataset was added to session
        mock_session.add.assert_called_once_with(dataset)

    def test_update_rag_pipeline_dataset_settings_published_chunk_structure_error(
        self, mock_session, mock_dataset_service_dependencies
    ):
        """
        Test error handling when trying to update chunk structure of published dataset.

        Verifies that when a dataset is published and has an existing chunk structure,
        attempting to change it raises a ValueError.

        This test ensures:
        - Chunk structure change is detected
        - ValueError is raised with appropriate message
        - No updates are committed
        """
        # Arrange
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id="dataset-123",
            runtime_mode="rag_pipeline",
            chunk_structure="tree",  # Existing structure
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        )

        knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
            chunk_structure="list",  # Different structure
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,
        )

        mock_session.merge.return_value = dataset

        # Act & Assert
        with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"):
            DatasetService.update_rag_pipeline_dataset_settings(
                mock_session, dataset, knowledge_config, has_published=True
            )

        # Verify no commit was attempted
        mock_session.commit.assert_not_called()

    def test_update_rag_pipeline_dataset_settings_published_economy_error(
        self, mock_session, mock_dataset_service_dependencies
    ):
        """
        Test error handling when trying to change to economy indexing on published dataset.

        Verifies that when a dataset is published, changing indexing technique to
        economy is not allowed and raises a ValueError.

        This test ensures:
        - Economy indexing change is detected
        - ValueError is raised with appropriate message
        - No updates are committed
        """
        # Arrange
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(
            dataset_id="dataset-123",
            runtime_mode="rag_pipeline",
            indexing_technique=IndexTechniqueType.HIGH_QUALITY,  # Current technique
        )

        knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock(
            indexing_technique=IndexTechniqueType.ECONOMY,  # Trying to change to economy
        )

        mock_session.merge.return_value = dataset

        # Act & Assert
        with pytest.raises(
            ValueError, match="Knowledge base indexing technique is not allowed to be updated to economy"
        ):
            DatasetService.update_rag_pipeline_dataset_settings(
                mock_session, dataset, knowledge_config, has_published=True
            )

    def test_update_rag_pipeline_dataset_settings_missing_current_user_error(
        self, mock_session, mock_dataset_service_dependencies
    ):
        """
        Test error handling when current_user is missing.

        Verifies that when current_user is None or has no tenant ID, a ValueError
        is raised.

        This test ensures:
        - Current user validation works correctly
        - Error message is clear
        - No updates are performed
        """
        # Arrange
        dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock()
        knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock()

        mock_dataset_service_dependencies["current_user"].current_tenant_id = None  # Missing tenant

        # Act & Assert
        with pytest.raises(ValueError, match="Current user or current tenant not found"):
            DatasetService.update_rag_pipeline_dataset_settings(
                mock_session, dataset, knowledge_config, has_published=False
            )


# ============================================================================
# Additional Documentation and Notes
# ============================================================================
#
# This test suite covers the core update and delete operations for datasets.
# Additional test scenarios that could be added:
#
# 1. Update Operations:
#    - Testing with different indexing techniques
#    - Testing embedding model provider changes
#    - Testing retrieval model updates
#    - Testing icon_info updates
#    - Testing partial_member_list updates
#
# 2. Delete Operations:
#    - Testing cascade deletion of related data
#    - Testing event handler execution
#    - Testing with datasets that have documents
#    - Testing with datasets that have segments
#
# 3. RAG Pipeline Operations:
#    - Testing economy indexing technique updates
#    - Testing embedding model provider errors
#    - Testing keyword_number updates
#    - Testing index update task triggering
#
# 4. Integration Scenarios:
#    - Testing update followed by delete
#    - Testing multiple updates in sequence
#    - Testing concurrent update attempts
#    - Testing with different user roles
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
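# As an illustration, the "different user roles" scenario could start from a
# skeleton like this (hypothetical, not part of the current suite):
#
#     def test_update_dataset_as_admin(self, mock_dataset_service_dependencies):
#         user = DatasetUpdateDeleteTestDataFactory.create_user_mock(role=TenantAccountRole.ADMIN)
#         dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock()
#         mock_dataset_service_dependencies["get_dataset"].return_value = dataset
#         mock_dataset_service_dependencies["has_same_name"].return_value = False
#         ...
#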
# ============================================================================
@@ -247,10 +247,11 @@ workflow:
    dataset_mock = Mock()
    dataset_mock.id = "d1"
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())

    session = cast(MagicMock, Mock())
    service = RagPipelineDslService(session=cast(Session, session))
    session.query.return_value.filter_by.return_value.all.return_value = []
    session.scalars.return_value.all.return_value = []
    account = Mock(current_tenant_id="t1")

    result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content)
@@ -320,6 +321,7 @@ workflow:
    dataset_mock.id = "d1"
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", return_value=Mock(id="b1"))
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())

    service = RagPipelineDslService(session=Mock())
    # Mocking self._session.scalar for the pipeline lookup
@@ -406,12 +408,14 @@ def test_create_or_update_pipeline_create_new(mocker) -> None:

    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock())
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
    pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
    pipeline_instance = pipeline_cls.return_value
    pipeline_instance.tenant_id = "t1"
    pipeline_instance.id = "p1"
    pipeline_instance.name = "P"
    pipeline_instance.is_published = False
    session.scalar.return_value = None

    result = service._create_or_update_pipeline(pipeline=None, data=data, account=account, dependencies=[])

@@ -447,8 +451,7 @@ def test_export_rag_pipeline_dsl_with_workflow(mocker) -> None:
    workflow.rag_pipeline_variables = []
    workflow.to_dict.return_value = {"graph": {"nodes": []}}

    # Mocking single .where() call
    session.query.return_value.where.return_value.first.return_value = workflow
    session.scalar.return_value = workflow
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
        return_value=[],
@@ -550,7 +553,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
            ]
        }
    }
    session.query.return_value.where.return_value.first.return_value = workflow
    session.scalar.return_value = workflow
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
        return_value=[],
@@ -568,7 +571,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
    session = cast(MagicMock, Mock())
    service = RagPipelineDslService(session=cast(Session, session))
    session.query.return_value.filter_by.return_value.first.return_value = Mock()
    session.scalar.return_value = Mock()
    create_entity = RagPipelineDatasetCreateEntity(
        name="Existing Name",
        description="",
@@ -584,8 +587,8 @@ def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
def test_create_rag_pipeline_dataset_generates_name_when_missing(mocker) -> None:
    session = cast(MagicMock, Mock())
    service = RagPipelineDslService(session=cast(Session, session))
    session.query.return_value.filter_by.return_value.first.return_value = None
    session.query.return_value.filter_by.return_value.all.return_value = [Mock(name="Untitled")]
    session.scalar.return_value = None
    session.scalars.return_value.all.return_value = [Mock(name="Untitled")]
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="Untitled 2")
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", Mock(id="u1", current_tenant_id="t1"))
    mocker.patch.object(
@@ -632,7 +635,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo
            ]
        }
    }
    session.query.return_value.where.return_value.first.return_value = workflow
    session.scalar.return_value = workflow
    mocker.patch.object(service, "encrypt_dataset_id", side_effect=lambda dataset_id, tenant_id: f"enc-{dataset_id}")
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
@@ -727,7 +730,7 @@ def test_create_or_update_pipeline_decrypts_knowledge_retrieval_dataset_ids(mock
        },
    }
    draft_workflow = Mock(id="wf1")
    session.query.return_value.where.return_value.first.return_value = draft_workflow
    session.scalar.return_value = draft_workflow
    mocker.patch.object(service, "decrypt_dataset_id", side_effect=["d1", None])

    result = service._create_or_update_pipeline(pipeline=pipeline, data=data, account=account)
@@ -743,7 +746,8 @@ def test_create_or_update_pipeline_creates_draft_when_missing(mocker) -> None:
    account = Mock(id="u1", current_tenant_id="t1")
    pipeline = Mock(id="p1", tenant_id="t1", name="N", description="D")
    data = {"rag_pipeline": {"name": "N2", "description": "D2"}, "workflow": {"graph": {"nodes": []}}}
    session.query.return_value.where.return_value.first.return_value = None
    session.scalar.return_value = None
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
    workflow_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow")
    workflow_cls.return_value.id = "wf-new"

@@ -817,7 +821,7 @@ def test_import_rag_pipeline_fails_for_non_string_version_type() -> None:
def test_append_workflow_export_data_raises_when_draft_workflow_missing() -> None:
    session = cast(MagicMock, Mock())
    service = RagPipelineDslService(session=cast(Session, session))
    session.query.return_value.where.return_value.first.return_value = None
    session.scalar.return_value = None

    with pytest.raises(ValueError, match="Missing draft workflow configuration"):
        service._append_workflow_export_data(export_data={}, pipeline=Mock(tenant_id="t1"), include_secret=False)
@@ -841,7 +845,7 @@ def test_append_workflow_export_data_keeps_secret_fields_when_include_secret_tru
            ]
        }
    }
    session.query.return_value.where.return_value.first.return_value = workflow
    session.scalar.return_value = workflow
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
        return_value=[],
@@ -1003,7 +1007,8 @@ def test_import_rag_pipeline_sets_default_version_and_kind(mocker) -> None:
    )
    dataset = Mock(id="d1")
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset)
    session.query.return_value.filter_by.return_value.all.return_value = []
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
    session.scalars.return_value.all.return_value = []
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="P")

    result = service.import_rag_pipeline(
@@ -1061,7 +1066,7 @@ def test_append_workflow_export_data_skips_empty_node_data(mocker) -> None:
    workflow = Mock()
    workflow.graph_dict = {"nodes": []}
    workflow.to_dict.return_value = {"graph": {"nodes": [{"data": {}}, {}]}}
    session.query.return_value.where.return_value.first.return_value = workflow
    session.scalar.return_value = workflow
    mocker.patch(
        "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
        return_value=[],
@@ -1246,11 +1251,12 @@ def test_create_or_update_pipeline_saves_dependencies_to_redis(mocker) -> None:
    account = Mock(id="u1", current_tenant_id="t1")
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock(id="wf-1"))
    mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
    pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
    pipeline = pipeline_cls.return_value
    pipeline.tenant_id = "t1"
    pipeline.id = "p1"
    session.query.return_value.where.return_value.first.return_value = None
    session.scalar.return_value = None
    setex = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.setex")
    dependency = PluginDependency(
        type=PluginDependency.Type.Marketplace,

@@ -116,81 +116,6 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
    assert has_more is True


def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)

    with pytest.raises(ValueError, match="Dataset not found"):
        rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")


# --- update_customized_pipeline_template ---


def test_update_customized_pipeline_template_success(mocker) -> None:
    template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)

    # First scalar finds the template, second scalar (duplicate check) returns None
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))

    info = PipelineTemplateInfoEntity(
        name="new",
        description="new desc",
        icon_info=IconInfo(icon="🔥"),
    )
    result = RagPipelineService.update_customized_pipeline_template("tpl-1", info)

    assert result.name == "new"
    assert result.description == "new desc"


def test_update_customized_pipeline_template_not_found(mocker) -> None:
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))

    info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
    with pytest.raises(ValueError, match="Customized pipeline template not found"):
        RagPipelineService.update_customized_pipeline_template("tpl-missing", info)


def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
    template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
    duplicate = SimpleNamespace(name="dup")

    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))

    info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
    with pytest.raises(ValueError, match="Template name is already exists"):
        RagPipelineService.update_customized_pipeline_template("tpl-1", info)


# --- delete_customized_pipeline_template ---


def test_delete_customized_pipeline_template_success(mocker) -> None:
    template = SimpleNamespace(id="tpl-1")
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
    delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
    commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")

    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))

    RagPipelineService.delete_customized_pipeline_template("tpl-1")

    delete_mock.assert_called_once_with(template)
    commit_mock.assert_called_once()


def test_delete_customized_pipeline_template_not_found(mocker) -> None:
    mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
    mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))

    with pytest.raises(ValueError, match="Customized pipeline template not found"):
        RagPipelineService.delete_customized_pipeline_template("tpl-missing")


# --- sync_draft_workflow ---


@@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs:
    def mock_session(self):
        """Create a mock database session."""
        session = Mock(spec=Session)
        session.query.return_value.filter.return_value.all.return_value = []
        session.query.return_value.filter.return_value.delete.return_value = 0
        session.scalars.return_value.all.return_value = []
        return session

    @pytest.fixture
@@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs:
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])

            # Should not call any database operations
            mock_session.query.assert_not_called()
            mock_session.scalars.assert_not_called()
            mock_storage.save.assert_not_called()

    def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
        """Test when no related records are found."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.where.return_value.all.return_value = []
            mock_session.scalars.return_value.all.return_value = []

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should call query for each related table but find no records
            assert mock_session.query.call_count > 0
            # Should call scalars for each related table but find no records
            assert mock_session.scalars.call_count > 0
            mock_storage.save.assert_not_called()

    def test_clear_message_related_tables_with_records_and_to_dict(
@@ -73,7 +72,7 @@ class TestClearFreePlanTenantExpiredLogs:
    ):
        """Test when records are found and have to_dict method."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.where.return_value.all.return_value = sample_records
            mock_session.scalars.return_value.all.return_value = sample_records

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

@@ -104,7 +103,7 @@ class TestClearFreePlanTenantExpiredLogs:
            records.append(record)

        # Mock records for first table only, empty for others
        mock_session.query.return_value.where.return_value.all.side_effect = [
        mock_session.scalars.return_value.all.side_effect = [
            records,
            [],
            [],
@@ -126,13 +125,13 @@ class TestClearFreePlanTenantExpiredLogs:
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_storage.save.side_effect = Exception("Storage error")

            mock_session.query.return_value.where.return_value.all.return_value = sample_records
            mock_session.scalars.return_value.all.return_value = sample_records

            # Should not raise exception
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should still delete records even if backup fails
            assert mock_session.query.return_value.where.return_value.delete.called
            assert mock_session.execute.called

    def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
        """Test that method continues even when record serialization fails."""
@@ -141,23 +140,23 @@ class TestClearFreePlanTenantExpiredLogs:
            record.id = "record-1"
            record.to_dict.side_effect = Exception("Serialization error")

            mock_session.query.return_value.where.return_value.all.return_value = [record]
            mock_session.scalars.return_value.all.return_value = [record]

            # Should not raise exception
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should still delete records even if serialization fails
            assert mock_session.query.return_value.where.return_value.delete.called
            assert mock_session.execute.called

    def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
        """Test that deletion is called for found records."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.where.return_value.all.return_value = sample_records
            mock_session.scalars.return_value.all.return_value = sample_records

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should call delete for each table that has records
            assert mock_session.query.return_value.where.return_value.delete.called
            # Should call execute(delete(...)) for each table that has records
            assert mock_session.execute.called

    def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes(
        self, mock_session, sample_message_ids
@@ -167,12 +166,12 @@ class TestClearFreePlanTenantExpiredLogs:
        record.to_dict.side_effect = Exception("Serialization error")

        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.where.return_value.all.return_value = [record]
            mock_session.scalars.return_value.all.return_value = [record]

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            mock_storage.save.assert_not_called()
            assert mock_session.query.return_value.where.return_value.delete.called
            assert mock_session.execute.called


class _ImmediateFuture:
@@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
    conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"})
    log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"})

    def make_query_with_batches(batches: list[list[object]]):
        q = MagicMock()
        q.where.return_value = q
        q.limit.return_value = q
        q.all.side_effect = batches
        q.delete.return_value = 1
        return q

    msg_session_1 = MagicMock()
    msg_session_1.query.side_effect = lambda model: (
        make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
    )
    msg_session_1.scalars.return_value.all.return_value = [msg1]

    msg_session_2 = MagicMock()
    msg_session_2.query.side_effect = lambda model: (
        make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
    )
    msg_session_2.scalars.return_value.all.return_value = []

    conv_session_1 = MagicMock()
    conv_session_1.query.side_effect = lambda model: (
        make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
    )
    conv_session_1.scalars.return_value.all.return_value = [conv1]

    conv_session_2 = MagicMock()
    conv_session_2.query.side_effect = lambda model: (
        make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
    )
    conv_session_2.scalars.return_value.all.return_value = []

    wal_session_1 = MagicMock()
    wal_session_1.query.side_effect = lambda model: (
        make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
    )
    wal_session_1.scalars.return_value.all.return_value = [log1]

    wal_session_2 = MagicMock()
    wal_session_2.query.side_effect = lambda model: (
        make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
    )
    wal_session_2.scalars.return_value.all.return_value = []

    session_wrappers = [
        _sessionmaker_wrapper_for_begin(msg_session_1),
@@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py

    # Total tenant count query
    count_session = MagicMock()
    count_query = MagicMock()
    count_query.count.return_value = 2
    count_session.query.return_value = count_query
    count_session.scalar.return_value = 2

    monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))

@@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt

    # Sessions used:
    # 1) total tenant count
    # 2) per-batch tenant scan (count + tenant list)
    # 2) per-batch tenant scan (interval counts + tenant list)
    total_session = MagicMock()
    total_query = MagicMock()
    total_query.count.return_value = 250
    total_session.query.return_value = total_query

    batch_session = MagicMock()
    q1 = MagicMock()
    q1.where.return_value = q1
    q1.count.return_value = 200
    q2 = MagicMock()
    q2.where.return_value = q2
    q2.count.return_value = 200
    q3 = MagicMock()
    q3.where.return_value = q3
    q3.count.return_value = 200
    q4 = MagicMock()
    q4.where.return_value = q4
    q4.count.return_value = 50  # choose this interval, then scale it
    total_session.scalar.return_value = 250

    rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")]
    q_rs = MagicMock()
    q_rs.where.return_value = q_rs
    q_rs.order_by.return_value = rows

    batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
    batch_session = MagicMock()
    # 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h)
    batch_session.scalar.side_effect = [200, 200, 200, 50]
    batch_session.execute.return_value = rows

    sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
    monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
    monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object()))

    count_session = MagicMock()
    count_query = MagicMock()
    count_query.count.return_value = 100
    count_session.query.return_value = count_query
    count_session.scalar.return_value = 100
    monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))

    flask_app = service_module.Flask("test-app")
@@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
    monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None)

    total_session = MagicMock()
    total_query = MagicMock()
    total_query.count.return_value = 250
    total_session.query.return_value = total_query

    batch_session = MagicMock()
    # Count results for all 5 intervals, all > 100 => take the for-else path.
    count_queries = []
    for _ in range(5):
        q = MagicMock()
        q.where.return_value = q
        q.count.return_value = 200
        count_queries.append(q)
    total_session.scalar.return_value = 250

    rows = [SimpleNamespace(id="tenant-a")]
    q_rs = MagicMock()
    q_rs.where.return_value = q_rs
    q_rs.order_by.return_value = rows

    batch_session.query.side_effect = [*count_queries, q_rs]
    batch_session = MagicMock()
    # All 5 intervals have > 100 tenants => for-else falls through to min interval (1h)
    batch_session.scalar.side_effect = [200, 200, 200, 200, 200]
    batch_session.execute.return_value = rows

    sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
    monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
@@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
        ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[])

    assert process_tenant_mock.call_count == 1
    assert len(count_queries) == 5
    assert batch_session.query.call_count >= 6
    assert batch_session.scalar.call_count == 5


def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte

    # Make message/conversation/workflow_app_log loops no-op (empty immediately)
    empty_session = MagicMock()
    q_empty = MagicMock()
    q_empty.where.return_value = q_empty
    q_empty.limit.return_value = q_empty
    q_empty.all.return_value = []
    empty_session.query.return_value = q_empty
    empty_session.scalars.return_value.all.return_value = []
    session_wrappers = [
        _sessionmaker_wrapper_for_begin(empty_session),
        _sessionmaker_wrapper_for_begin(empty_session),

@@ -577,7 +577,7 @@ class TestDatasetServiceCreationAndUpdate:
    def test_update_external_knowledge_binding_updates_changed_binding_values(self):
        binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
        session = MagicMock()
        session.query.return_value.filter_by.return_value.first.return_value = binding
        session.scalar.return_value = binding
        session.add = MagicMock()
        session_context = _make_session_context(session)

@@ -596,7 +596,7 @@ class TestDatasetServiceCreationAndUpdate:

    def test_update_external_knowledge_binding_raises_for_missing_binding(self):
        session = MagicMock()
        session.query.return_value.filter_by.return_value.first.return_value = None
        session.scalar.return_value = None
        session_context = _make_session_context(session)

        mock_sessionmaker = MagicMock()

@@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:

    def test_update_documents_need_summary_updates_matching_documents_and_commits(self):
        session = MagicMock()
        session.query.return_value.filter.return_value.update.return_value = 2
        session.execute.return_value.rowcount = 2

        with patch("services.dataset_service.session_factory") as session_factory_mock:
            session_factory_mock.create_session.return_value = _make_session_context(session)
@@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
        assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
        assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False

    def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
        knowledge_config = KnowledgeConfig(
            indexing_technique="economy",
            data_source=DataSource(
                info_list=InfoList(
                    data_source_type="upload_file",
                    file_info_list=FileInfo(file_ids=["file-1"]),
                )
            ),
            process_rule=ProcessRule(
                mode="hierarchical",
                rules=Rule(
                    pre_processing_rules=[
                        PreProcessingRule(id="remove_extra_spaces", enabled=True),
                    ],
                    segmentation=Segmentation(separator="\n", max_tokens=1024),
                    subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
                ),
            ),
        )

        DocumentService.process_rule_args_validate(knowledge_config)

        assert knowledge_config.process_rule is not None
        assert knowledge_config.process_rule.rules is not None
        assert knowledge_config.process_rule.rules.parent_mode == "paragraph"


class TestDocumentServiceSaveDocumentWithDatasetId:
    """Unit tests for non-SQL validation branches in save_document_with_dataset_id."""

@ -179,11 +179,11 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock()
mock_db_session.scalar.return_value = MagicMock()
assert service.is_system_oauth_params_exist(make_id()) is True

def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.is_system_oauth_params_exist(make_id()) is False

# -----------------------------------------------------------------------
@ -205,7 +205,7 @@ class TestDatasourceProviderService:

def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
service.remove_oauth_custom_client_params("t1", make_id())
mock_db_session.query().delete.assert_called_once()
mock_db_session.execute.assert_called_once()

# -----------------------------------------------------------------------
# setup_oauth_custom_client_params (315-351)
@ -217,14 +217,14 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_not_called()

def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
mock_db_session.add.assert_called_once()

def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
existing = MagicMock()
mock_db_session.query().first.return_value = existing
mock_db_session.scalar.return_value = existing
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
mock_db_session.add.assert_not_called()  # update in place, no add
@ -255,7 +255,7 @@ class TestDatasourceProviderService:

def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}

def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
@ -264,7 +264,7 @@ class TestDatasourceProviderService:
p.auth_type = "oauth2"
p.expires_at = 0  # expired
p.encrypted_credentials = {"tok": "x"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@ -278,7 +278,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.expires_at = -1  # sentinel: never expires
p.encrypted_credentials = {"k": "v"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
@ -292,7 +292,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.expires_at = -1
p.encrypted_credentials = {}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
@ -306,7 +306,7 @@ class TestDatasourceProviderService:

def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []

def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
@ -314,7 +314,7 @@ class TestDatasourceProviderService:
p.auth_type = "oauth2"
p.expires_at = 0
p.encrypted_credentials = {"t": "x"}
mock_db_session.query().all.return_value = [p]
mock_db_session.scalars.return_value.all.return_value = [p]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@ -328,22 +328,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")

def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "same"
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")

def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1  # conflict
mock_db_session.scalar.side_effect = [p, 1]  # first: fetch provider, second: name conflict count
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")

@ -351,8 +350,7 @@ class TestDatasourceProviderService:
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"

@ -361,7 +359,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.set_default_datasource_provider("t1", make_id(), "bad-id")

@ -369,7 +367,7 @@ class TestDatasourceProviderService:
target = MagicMock(spec=DatasourceProvider)
target.provider = "provider"
target.plugin_id = "org/plug"
mock_db_session.query().first.return_value = target
mock_db_session.scalar.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True

@ -428,13 +426,13 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_use_tenant_config_when_available(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"})
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_oauth_client("t1", make_id())
assert result == {"k": "dec"}

def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
@ -444,7 +442,7 @@ class TestDatasourceProviderService:

def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
"""Neither tenant nor system credentials → raises ValueError."""
mock_db_session.query().first.side_effect = [None, None]
mock_db_session.scalar.side_effect = [None, None]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
@ -457,15 +455,14 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()

def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
mock_db_session.query().count.return_value = 1  # conflict first, then auto-named
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 1  # conflict first, then auto-named
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
@ -475,8 +472,7 @@ class TestDatasourceProviderService:

def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
mock_db_session.query().count.return_value = 0
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 0
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
@ -485,13 +481,13 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()

def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
self._enc.encrypt_token.assert_called()

def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
self._redis.lock.assert_called()
@ -501,23 +497,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "extract_secret_variables", return_value=[]):
with pytest.raises(ValueError, match="not found"):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")

def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")

def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1  # conflict
mock_db_session.query().all.return_value = []
mock_db_session.scalar.side_effect = [p, 1]  # first: fetch provider, second: name conflict count
mock_db_session.scalars.return_value.all.return_value = []
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
@ -525,16 +519,14 @@ class TestDatasourceProviderService:

def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
self._enc.encrypt_token.assert_called()

def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
self._redis.lock.assert_called()
@ -545,13 +537,13 @@ class TestDatasourceProviderService:

def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
"""explicit name supplied + conflict → raises ValueError immediately."""
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.return_value = 1
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})

def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
@ -561,7 +553,7 @@ class TestDatasourceProviderService:
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})

def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@ -571,7 +563,7 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()

def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@ -694,7 +686,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------

def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="not found"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
@ -704,8 +696,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.side_effect = [p, 1]  # first: fetch provider, second: name conflict count
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
@ -717,8 +708,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
@ -733,8 +723,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "old_enc"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0]  # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),

@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
existing.disabled_by = "u"

session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = existing
session.query.return_value = query
session.scalar.return_value = existing

create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes

def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None

create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat

# New session used after vectorization succeeds (record not found by id nor chunk_id).
session = MagicMock(name="session")
q1 = MagicMock()
q1.filter_by.return_value = q1
q1.first.side_effect = [None, None]
session.query.return_value = q1
session.scalar.side_effect = [None, None]

create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes

# error_session should find record and commit status update
error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary

create_session_mock = MagicMock(return_value=_SessionContext(error_session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo
existing.enabled = False

session = MagicMock()
query = MagicMock()
query.filter.return_value = query
query.all.return_value = [existing]
session.query.return_value = query
session.scalars.return_value.all.return_value = [existing]

monkeypatch.setattr(
summary_module,
@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon
record = _summary_record()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch)
record = _summary_record(summary_content="")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record

monkeypatch.setattr(
summary_module,
@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch
record = _summary_record(summary_content="")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record

monkeypatch.setattr(
summary_module,
@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
existing = _summary_record(summary_content="old", node_id="old-node")
existing.id = "other-id"
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, existing]  # miss by id, hit by chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, existing]  # miss by id, hit by chunk_id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte

existing = _summary_record(summary_content="old", node_id="old-node")
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = existing  # hit by id
session.query.return_value = q
session.scalar.return_value = existing  # hit by id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
return None

error_session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary

create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
)

session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None]  # miss by id and chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, None]  # miss by id and chunk_id

error_session = MagicMock()
eq = MagicMock()
eq.filter_by.return_value = eq
eq.first.return_value = summary
error_session.query.return_value = eq
error_session.scalar.return_value = summary

create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))

# Force the created record to be None so the "should not be None" guard triggers.
# Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause.
monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock()))
monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None))

with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"):
@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_
)

error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None]  # not found by id, not found by chunk_id
error_session.query.return_value = q
error_session.scalar.side_effect = [None, None]  # not found by id, not found by chunk_id

monkeypatch.setattr(
summary_module,
@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk
segment = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
segment = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
seg2.id = "seg-2"

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg1, seg2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg1, seg2]

monkeypatch.setattr(
summary_module,
@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
document.doc_form = IndexStructureType.PARAGRAPH_INDEX

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
seg = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
segment_ids=[seg.id],
only_parent_chunks=True,
)
query.filter.assert_called()
session.scalars.assert_called()


def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None:
@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
summary2 = _summary_record(summary_content="s", node_id=None)

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary1, summary2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary1, summary2]

monkeypatch.setattr(
summary_module,
@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
segment.status = SegmentStatus.COMPLETED

session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary]

seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.return_value = segment

def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query

session.query.side_effect = query_side_effect
session.scalars.return_value.all.return_value = [summary]
session.scalar.return_value = segment

monkeypatch.setattr(
summary_module,
@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect
good_segment.status = SegmentStatus.COMPLETED

session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary1, summary2, summary3]
session.scalars.return_value.all.return_value = [summary1, summary2, summary3]
session.scalar.side_effect = [bad_segment, good_segment, good_segment]

seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.side_effect = [bad_segment, good_segment, good_segment]

def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query

session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
summary = _summary_record(summary_content="sum", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary]

vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record

vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py
segment = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record

vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke
segment = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
record = _summary_record(summary_content="old", node_id="n1")

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
session.flush.side_effect = RuntimeError("flush boom")
monkeypatch.setattr(
summary_module,
@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None:
record = _summary_record(summary_content="sum", node_id="n1")
session = MagicMock()
session.scalar.return_value = record
session.scalars.return_value.all.return_value = [record]

q1 = MagicMock()
q1.where.return_value = q1
q1.first.return_value = record

q2 = MagicMock()
q2.filter.return_value = q2
q2.all.return_value = [record]

def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
# first call used by get_segment_summary, second by get_document_summaries
if not hasattr(query_side_effect, "_called"):
query_side_effect._called = True  # type: ignore[attr-defined]
return q1
return q2
return MagicMock()

session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
record2 = _summary_record()
record2.chunk_id = "seg-2"
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [record1, record2]
session.query.return_value = q
session.scalars.return_value.all.return_value = [record1, record2]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No

def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = []
session.query.return_value = q
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk

def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
session.query.return_value = q
session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro
segment = _segment()

session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None

monkeypatch.setattr(
summary_module,
@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None:
def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None:
seg_row = SimpleNamespace(id="seg-1", document_id="doc-1")
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.all.return_value = [SimpleNamespace(id="seg-1")]
session.query.return_value = query
session.scalars.return_value.all.return_value = ["seg-1"]  # get_document_summary_index_status returns IDs

create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypat
assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING"

# Multiple docs
query2 = MagicMock()
query2.where.return_value = query2
query2.all.return_value = [seg_row]
session2 = MagicMock()
session2.query.return_value = query2
session2.execute.return_value.all.return_value = [seg_row]  # get_documents_summary_index_status uses execute
monkeypatch.setattr(
summary_module,
"session_factory",

@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su
provider_id: TriggerProviderID,
) -> None:
# Arrange
query = MagicMock()
query.filter_by.return_value.order_by.return_value.all.return_value = []
mock_session.query.return_value = query
mock_session.scalars.return_value.all.return_value = []

# Act
result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id)
@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf
db_sub = SimpleNamespace(to_api_entity=lambda: api_sub)
usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2)

query_subs = MagicMock()
query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub]
query_usage = MagicMock()
query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row]
mock_session.query.side_effect = [query_subs, query_usage]
mock_session.scalars.return_value.all.return_value = [db_sub]
mock_session.execute.return_value.all.return_value = [usage_row]

_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"})
@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None]  # count=0, no existing name

_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(encrypted={"api_key": "enc"})
@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None]  # count=0, no existing name

_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(encrypted={"p": "enc"})
@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
mock_session.query.return_value = query_count
mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
_mock_get_trigger_provider(mocker, provider_controller)
mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger")

@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists(
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, object()]  # count=0, existing name conflict
_mock_get_trigger_provider(mocker, provider_controller)

# Act + Assert
@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = None
mock_session.query.return_value = query_sub
mock_session.scalar.return_value = None

# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts(
provider_id="langgenius/github/github",
credential_type=CredentialType.API_KEY.value,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, object()]  # found sub, name conflict
_mock_get_trigger_provider(mocker, provider_controller)

# Act + Assert
@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
credential_expires_at=0,
expires_at=0,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, None]  # found sub, no name conflict

_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"})
@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(

def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None

# Act
result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1")
@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties(
credentials={"token": "enc"},
properties={"project": "enc"},
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
prop_enc = _encrypter_mock(decrypted={"project": "plain"})
@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing(
mock_session: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None

# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri
credentials={"token": "enc"},
to_entity=lambda: SimpleNamespace(id="sub-1"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
mocker.patch(
@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized(
credentials={},
to_entity=lambda: SimpleNamespace(id="sub-2"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger")
mocker.patch(
@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None

# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials(
) -> None:
# Arrange
subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription

# Act + Assert
with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"):
@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
credentials={"access_token": "enc"},
credential_expires_at=0,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cache = MagicMock()
cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"})
@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None

# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
subscription = SimpleNamespace(expires_at=200)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription

# Act
result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100)
@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
credentials={"c": "enc"},
credential_type=CredentialType.API_KEY.value,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"c": "plain"})
prop_cache = MagicMock()
@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available(
) -> None:
# Arrange
tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"})
system_client = None
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = tenant_client
mock_session.query.return_value = query_tenant
mock_session.scalar.return_value = tenant_client
_mock_get_trigger_provider(mocker, provider_controller)
enc = _encrypter_mock(decrypted={"client_id": "plain"})
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.return_value = None  # no tenant client; plugin not verified → early return
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False)

@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record(
provider_controller: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None
mock_session.scalar.return_value = object() if has_client else None
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)

@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.return_value = query
|
||||
mock_session.scalar.return_value = None
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={})
|
||||
# Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient.
|
||||
mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock()))
|
||||
mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model)
|
||||
|
||||
# Act
|
||||
@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
|
||||
) -> None:
|
||||
# Arrange
|
||||
custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
|
||||
mock_session.scalar.return_value = custom_client
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cache = MagicMock()
|
||||
enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"})
|
||||
@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id)
|
||||
@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values(
|
||||
) -> None:
|
||||
# Arrange
|
||||
custom_client = SimpleNamespace(oauth_params={"client_id": "enc"})
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
|
||||
mock_session.scalar.return_value = custom_client
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"})
|
||||
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
|
||||
@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
|
||||
mock_session: MagicMock,
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id)
|
||||
|
||||
@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean(
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None
|
||||
mock_session.scalar.return_value = object() if exists else None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id)
|
||||
@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found(
|
||||
mocker: MockerFixture, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1")
|
||||
@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties(
|
||||
credentials={"token": "enc"},
|
||||
properties={"hook": "enc"},
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription",
|
||||
|
||||
@ -969,8 +969,7 @@ class TestWorkflowService:
# 1. Workflow exists
# 2. No app is currently using it
# 3. Not published as a tool
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider
mock_session.scalar.side_effect = [mock_workflow, None, None] # workflow, no app using it, no tool provider

with patch("services.workflow_service.select") as mock_select:
mock_stmt = MagicMock()
@ -1045,8 +1044,7 @@ class TestWorkflowService:
mock_tool_provider = MagicMock()

mock_session = MagicMock()
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider
mock_session.scalar.side_effect = [mock_workflow, None, mock_tool_provider] # workflow, no app, tool provider

with patch("services.workflow_service.select") as mock_select:
mock_stmt = MagicMock()

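One detail worth noting in the workflow hunks above: `side_effect` hands back one value per call, so the list now encodes the exact order of the three `session.scalar()` lookups the service performs. A minimal illustration (the statement labels are hypothetical, just to show the sequencing):

from unittest.mock import MagicMock

mock_workflow = object()
session = MagicMock()
# Each successive scalar() call pops the next value from the list.
session.scalar.side_effect = [mock_workflow, None, None]

assert session.scalar("select workflow") is mock_workflow  # 1st lookup
assert session.scalar("select app reference") is None      # 2nd lookup
assert session.scalar("select tool provider") is None      # 3rd lookup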
@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams:
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")

assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.execute.assert_called_once()


class TestListBuiltinToolProviderTools:
@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_true_when_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
session.scalar.return_value = MagicMock()

assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True

@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_false_when_missing(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None

assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False

@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_true_when_enabled(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
session.scalar.return_value = MagicMock(enabled=True)

assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True

@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_false_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None

assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False

@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None

with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider:
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
mock_cache = MagicMock()
mock_enc.return_value = (MagicMock(), mock_cache)

@ -177,7 +177,7 @@ class TestSetDefaultProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None

with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@ -187,7 +187,7 @@ class TestSetDefaultProvider:
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target
session.scalar.return_value = target

result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")

@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None

with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider:
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider

mock_cred_instance = MagicMock()
mock_cred_instance.is_editable.return_value = True
@ -274,7 +274,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (mock_encrypter, MagicMock())

user_client = MagicMock(oauth_params='{"encrypted": "data"}')
session.query.return_value.filter_by.return_value.first.return_value = user_client
session.scalar.return_value = user_client

result = BuiltinToolManageService.get_oauth_client("t", "google")

@ -297,10 +297,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (MagicMock(), MagicMock())

system_client = MagicMock(encrypted_oauth_params="enc")
session.query.return_value.filter_by.return_value.first.side_effect = [
None, # user client
system_client, # system client
]
session.scalar.side_effect = [None, system_client]

result = BuiltinToolManageService.get_oauth_client("t", "google")

@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams:
@patch(f"{MODULE}.db")
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None

result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")

@ -391,7 +388,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.return_value.provider_name = "google"
mock_prov_id.return_value.organization = "langgenius"
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
session.scalar.return_value = None

result = BuiltinToolManageService.get_builtin_provider("google", "t")

@ -417,7 +414,7 @@ class TestGetBuiltinProvider:
return m

mock_prov_id.side_effect = prov_id_side_effect
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider

result = BuiltinToolManageService.get_builtin_provider("google", "t")

@ -439,7 +436,7 @@ class TestGetBuiltinProvider:

mock_prov_id.side_effect = prov_id_side_effect
db_provider = MagicMock(provider="third-party/custom/custom-tool")
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider

result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")

@ -452,7 +449,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.side_effect = Exception("parse error")
fallback = MagicMock()
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
session.scalar.return_value = fallback

result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")


@ -82,8 +82,8 @@ def mock_db_session():
"""Mock session_factory.create_session() to return a session whose queries use shared test data.

Tests set session._shared_data = {"dataset": <Dataset>, "documents": [<Document>, ...]}
This fixture makes session.query(Dataset).first() return the shared dataset,
and session.query(Document).all()/first() return from the shared documents.
This fixture makes session.scalar(select(Dataset)...) return the shared dataset,
and session.scalars(select(Document)...).all() return the shared documents.
"""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
@ -92,93 +92,68 @@ def mock_db_session():
# Keep a pointer so repeated Document.first() calls iterate across provided docs
session._doc_first_idx = 0

def _query_side_effect(model):
q = MagicMock()
def _get_entity(stmt) -> type | None:
"""Extract the mapped entity class from a SQLAlchemy select statement."""
try:
descs = stmt.column_descriptions
if descs:
return descs[0].get("entity")
except (AttributeError, TypeError):
pass
return None

# Capture filters passed via where(...) so first()/all() can honor them.
q._filters = {}
def _extract_id_from_where(stmt) -> str | None:
"""Return the value bound to the 'id' column in the WHERE clause, if present."""
try:
where = stmt.whereclause
if where is None:
return None
# Both single-clause and AND-clause-list cases
clauses = list(getattr(where, "clauses", [where]))
for clause in clauses:
left = getattr(clause, "left", None)
right = getattr(clause, "right", None)
if left is not None and right is not None:
if getattr(left, "key", None) == "id":
return getattr(right, "value", None)
except Exception:
pass
return None

def _extract_filters(*conds, **kw):
# Support both SQLAlchemy expressions (BinaryExpression) and kwargs
# We only need the simple fields used by production code: id, dataset_id, and id.in_(...)
for cond in conds:
left = getattr(cond, "left", None)
right = getattr(cond, "right", None)
key = None
if left is not None:
key = getattr(left, "key", None) or getattr(left, "name", None)
if not key:
continue
# Right side might be a BindParameter with .value, or a raw value/sequence
val = getattr(right, "value", right)
q._filters[key] = val
# Also accept kwargs (e.g., where(id=...)) just in case
for k, v in kw.items():
q._filters[k] = v

def _where_side_effect(*conds, **kw):
_extract_filters(*conds, **kw)
return q

q.where.side_effect = _where_side_effect

# Dataset queries
if model.__name__ == "Dataset":

def _dataset_first():
ds = session._shared_data.get("dataset")
if not ds:
return None
if "id" in q._filters:
val = q._filters["id"]
if isinstance(val, (list, tuple, set)):
return ds if ds.id in val else None
return ds if ds.id == val else None
return ds

def _dataset_all():
ds = session._shared_data.get("dataset")
if not ds:
return []
first = _dataset_first()
return [first] if first else []

q.first.side_effect = _dataset_first
q.all.side_effect = _dataset_all
return q

# Document queries
if model.__name__ == "Document":

def _apply_doc_filters(docs):
result = list(docs)
for key in ("id", "dataset_id"):
if key in q._filters:
val = q._filters[key]
if isinstance(val, (list, tuple, set)):
result = [d for d in result if getattr(d, key, None) in val]
else:
result = [d for d in result if getattr(d, key, None) == val]
return result

def _docs_all():
def _scalar_side_effect(stmt):
entity = _get_entity(stmt)
if entity is not None:
if entity.__name__ == "Dataset":
return session._shared_data.get("dataset")
elif entity.__name__ == "Document":
docs = session._shared_data.get("documents", [])
return _apply_doc_filters(docs)
if not docs:
return None
# When the WHERE clause filters by id, return the matching document
queried_id = _extract_id_from_where(stmt)
if queried_id:
doc_map = {d.id: d for d in docs}
return doc_map.get(queried_id, docs[0])
return docs[0]
return None

def _docs_first():
docs = _docs_all()
return docs[0] if docs else None
def _scalars_side_effect(stmt):
entity = _get_entity(stmt)
result = MagicMock()
if entity is not None:
if entity.__name__ == "Document":
result.all.return_value = list(session._shared_data.get("documents", []))
elif entity.__name__ == "Dataset":
ds = session._shared_data.get("dataset")
result.all.return_value = [ds] if ds else []
else:
result.all.return_value = []
else:
result.all.return_value = []
return result

q.all.side_effect = _docs_all
q.first.side_effect = _docs_first
return q

# Default fallback
q.first.return_value = None
q.all.return_value = []
return q

session.query.side_effect = _query_side_effect
session.scalar.side_effect = _scalar_side_effect
session.scalars.side_effect = _scalars_side_effect

# Implement session.begin() context manager that commits on exit
session.commit = MagicMock()
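The rewritten fixture above dispatches on the statement itself rather than on a model class passed to query(). That works because a 2.0-style Select exposes both the mapped entity and its bound parameters for inspection. A small self-contained sketch with an illustrative model (not the real Dify Dataset):

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):  # illustrative model for the sketch
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(primary_key=True)


stmt = select(Dataset).where(Dataset.id == "dataset-1")

# column_descriptions names the mapped entity, which is what _get_entity() reads
# to route select(Dataset) vs select(Document) without executing any SQL.
assert stmt.column_descriptions[0]["entity"] is Dataset

# Bound values are recoverable from the WHERE clause, which is how
# _extract_id_from_where() finds the queried id.
clause = stmt.whereclause
assert clause.left.key == "id" and clause.right.value == "dataset-1"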
@ -638,8 +613,6 @@ class TestProgressTracking:
wrapper = TaskWrapper(data=next_task_data)
mock_redis.rpop.return_value = wrapper.serialize()

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@ -662,7 +635,6 @@ class TestProgressTracking:
"""
# Arrange
mock_redis.rpop.return_value = None # No more tasks
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -780,8 +752,7 @@ class TestErrorHandling:

If the dataset doesn't exist, the task should exit gracefully.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Arrange - dataset is not in _shared_data (None by default), so scalar() returns None

# Act
_document_indexing(dataset_id, document_ids)
@ -806,8 +777,6 @@ class TestErrorHandling:
# Set up rpop to return task once for concurrency check
mock_redis.rpop.side_effect = [wrapper.serialize(), None]

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

# Make _document_indexing raise an error
with patch("tasks.document_indexing_task._document_indexing") as mock_indexing:
mock_indexing.side_effect = Exception("Processing failed")
@ -844,7 +813,7 @@ class TestErrorHandling:
# Mock rpop to return tasks one by one
mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None]

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -977,7 +946,7 @@ class TestAdvancedScenarios:

# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1070,7 +1039,7 @@ class TestAdvancedScenarios:

# Mock rpop to return tasks in FIFO order
mock_redis.rpop.side_effect = tasks + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1108,7 +1077,7 @@ class TestAdvancedScenarios:
"""
# Arrange
mock_redis.rpop.return_value = None # Empty queue
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
# Act
@ -1276,7 +1245,7 @@ class TestIntegration:
# First call returns task 2, second call returns None
mock_redis.rpop.side_effect = [wrapper.serialize(), None]

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1433,7 +1402,7 @@ class TestPerformanceScenarios:

# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset

with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset does not exist."""
# Arrange
session = MagicMock()
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = None
session.query.side_effect = lambda model: dataset_query
session = MagicMock()
session.scalar.return_value = None # dataset not found

create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow:
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)

dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset

document_query = MagicMock()
document_query.where.return_value = document_query
document_query.first.return_value = document

session = MagicMock()
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query

def _scalar_se(stmt):
entity = stmt.column_descriptions[0].get("entity")
if entity is Dataset:
return dataset
return document

session.scalar.side_effect = _scalar_se

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
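Alongside scalar(), the tests above stub session.scalars(...).all() for batch fetches. Since a MagicMock swallows the whole select() statement, the stub only needs to return an object whose .all() yields the fixture documents. A minimal sketch of the shape being mocked:

from types import SimpleNamespace
from unittest.mock import MagicMock

docs = [SimpleNamespace(id="doc-1"), SimpleNamespace(id="doc-2")]
session = MagicMock()
# Mirrors: session.scalars(select(Document).where(...)).all() -> docs
session.scalars.return_value = MagicMock(all=MagicMock(return_value=docs))

assert [d.id for d in session.scalars("select documents").all()] == ["doc-1", "doc-2"]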
@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow:
session2.begin.return_value = nullcontext()
session3 = MagicMock()

session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_document_query
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(
all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status])
)

create_session_mock = MagicMock(
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow:
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query

session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible]))

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset is missing after indexing."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.side_effect = [dataset, None]

document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]

session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = None # dataset not found on second query

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow:
_document_indexing("dataset-1", ["doc-1"])

# Assert
session3.query.assert_called()
session3.scalar.assert_called()

def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation skipped when indexing_technique is not high_quality."""
@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="economy",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset

document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]

session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query

session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test summary generation is skipped when indexing is paused."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset

document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]

session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))

create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test generic indexing runner exception is handled."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset

document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]

session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset

phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]

summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [_FalseyDocument("missing-doc")]

session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query

session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")]))

monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch
import pytest

from libs.archive_storage import ArchiveStorageNotConfiguredError
from models.workflow import WorkflowArchiveLog
from tasks.remove_app_and_related_data_task import (
_delete_app_workflow_archive_logs,
_delete_archived_workflow_run_files,
@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs:
assert params == {"tenant_id": tenant_id, "app_id": app_id}
assert name == "workflow archive log"

mock_query = MagicMock()
mock_delete_query = MagicMock()
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
mock_session = MagicMock()

delete_func(mock_db.session, "log-1")
delete_func(mock_session, "log-1")

mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()
mock_delete_query.delete.assert_called_once_with(synchronize_session=False)
mock_session.execute.assert_called_once()


class TestDeleteArchivedWorkflowRunFiles:

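The simplified assertion above reflects the same 2.0-style migration on the write path: instead of session.query(Model).where(...).delete(), the task now builds a sqlalchemy delete() statement and hands it to session.execute(), so the test only needs to check that execute() was called. A sketch under that assumption, with an illustrative model standing in for WorkflowArchiveLog:

from unittest.mock import MagicMock

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ArchiveLog(Base):  # hypothetical stand-in for WorkflowArchiveLog
    __tablename__ = "archive_logs"
    id: Mapped[str] = mapped_column(primary_key=True)


def delete_log(session, log_id: str) -> None:
    # Bulk delete expressed as a statement rather than a Query chain.
    session.execute(sa.delete(ArchiveLog).where(ArchiveLog.id == log_id))


session = MagicMock()
delete_log(session, "log-1")
session.execute.assert_called_once()  # the shape the updated test asserts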
6
api/uv.lock
generated
@ -4763,11 +4763,11 @@ wheels = [

[[package]]
name = "pypdf"
version = "6.9.2"
version = "6.10.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/9f/ca96abf18683ca12602065e4ed2bec9050b672c87d317f1079abc7b6d993/pypdf-6.10.0.tar.gz", hash = "sha256:4c5a48ba258c37024ec2505f7e8fd858525f5502784a2e1c8d415604af29f6ef", size = 5314833, upload-time = "2026-04-10T09:34:57.102Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" },
{ url = "https://files.pythonhosted.org/packages/55/f2/7ebe366f633f30a6ad105f650f44f24f98cb1335c4157d21ae47138b3482/pypdf-6.10.0-py3-none-any.whl", hash = "sha256:90005e959e1596c6e6c84c8b0ad383285b3e17011751cedd17f2ce8fcdfc86de", size = 334459, upload-time = "2026-04-10T09:34:54.966Z" },
]

[[package]]

@ -27,7 +27,7 @@ describe('NumberInputField', () => {

it('should update value when users click increment', () => {
render(<NumberInputField label="Count" />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.increment' }))
fireEvent.click(screen.getByRole('button', { name: 'Increment value' }))
expect(mockField.handleChange).toHaveBeenCalledWith(3)
})


@ -172,13 +172,13 @@ describe('NumberField wrapper', () => {

// Increment and decrement buttons should preserve accessible naming, icon fallbacks, and spacing variants.
describe('Control buttons', () => {
it('should provide localized aria labels and default icons when labels are not provided', () => {
it('should provide english fallback aria labels and default icons when labels are not provided', () => {
renderNumberField({
controlsProps: {},
})

const increment = screen.getByRole('button', { name: 'common.operation.increment' })
const decrement = screen.getByRole('button', { name: 'common.operation.decrement' })
const increment = screen.getByRole('button', { name: 'Increment value' })
const decrement = screen.getByRole('button', { name: 'Decrement value' })

expect(increment.querySelector('.i-ri-arrow-up-s-line')).toBeInTheDocument()
expect(decrement.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
@ -217,11 +217,11 @@ describe('NumberField wrapper', () => {
},
})

expect(screen.getByRole('button', { name: 'common.operation.increment' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.decrement' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Increment value' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Decrement value' })).toBeInTheDocument()
})

it('should rely on aria-labelledby when provided instead of injecting a translated aria-label', () => {
it('should rely on aria-labelledby when provided instead of injecting a fallback aria-label', () => {
render(
<>
<span id="increment-label">Increment from label</span>

@ -4,7 +4,6 @@ import type { VariantProps } from 'class-variance-authority'
import { NumberField as BaseNumberField } from '@base-ui/react/number-field'
import { cva } from 'class-variance-authority'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'

export const NumberField = BaseNumberField.Root
@ -188,18 +187,19 @@ type NumberFieldButtonVariantProps = Omit<

export type NumberFieldButtonProps = React.ComponentPropsWithoutRef<typeof BaseNumberField.Increment> & NumberFieldButtonVariantProps

const incrementAriaLabel = 'Increment value'
const decrementAriaLabel = 'Decrement value'

export function NumberFieldIncrement({
className,
children,
size = 'regular',
...props
}: NumberFieldButtonProps) {
const { t } = useTranslation()

return (
<BaseNumberField.Increment
{...props}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : t('operation.increment', { ns: 'common' }))}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : incrementAriaLabel)}
className={cn(numberFieldControlButtonVariants({ size, direction: 'increment' }), className)}
>
{children ?? <span aria-hidden="true" className="i-ri-arrow-up-s-line size-3" />}
@ -213,12 +213,10 @@ export function NumberFieldDecrement({
size = 'regular',
...props
}: NumberFieldButtonProps) {
const { t } = useTranslation()

return (
<BaseNumberField.Decrement
{...props}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : t('operation.decrement', { ns: 'common' }))}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : decrementAriaLabel)}
className={cn(numberFieldControlButtonVariants({ size, direction: 'decrement' }), className)}
>
{children ?? <span aria-hidden="true" className="i-ri-arrow-down-s-line size-3" />}

@ -31,13 +31,13 @@ describe('base/ui/toast', () => {

expect(await screen.findByText('Saved')).toBeInTheDocument()
expect(screen.getByText('Your changes are available now.')).toBeInTheDocument()
const viewport = screen.getByRole('region', { name: 'common.toast.notifications' })
const viewport = screen.getByRole('region', { name: 'Notifications' })
expect(viewport).toHaveAttribute('aria-live', 'polite')
expect(viewport).toHaveClass('z-1101')
expect(viewport.firstElementChild).toHaveClass('top-4')
expect(screen.getByRole('dialog')).not.toHaveClass('outline-hidden')
expect(document.body.querySelector('[aria-hidden="true"].i-ri-checkbox-circle-fill')).toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="common.toast.close"][aria-hidden="true"]')).toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="Close notification"][aria-hidden="true"]')).toBeInTheDocument()
})

// Collapsed stacks should keep multiple toast roots mounted for smooth stack animation.
@ -57,12 +57,12 @@ describe('base/ui/toast', () => {

expect(await screen.findByText('Third toast')).toBeInTheDocument()
expect(screen.getAllByRole('dialog')).toHaveLength(3)
expect(document.body.querySelectorAll('button[aria-label="common.toast.close"][aria-hidden="true"]')).toHaveLength(3)
expect(document.body.querySelectorAll('button[aria-label="Close notification"][aria-hidden="true"]')).toHaveLength(3)

fireEvent.mouseEnter(screen.getByRole('region', { name: 'common.toast.notifications' }))
fireEvent.mouseEnter(screen.getByRole('region', { name: 'Notifications' }))

await waitFor(() => {
expect(document.body.querySelector('button[aria-label="common.toast.close"][aria-hidden="true"]')).not.toBeInTheDocument()
expect(document.body.querySelector('button[aria-label="Close notification"][aria-hidden="true"]')).not.toBeInTheDocument()
})
})

@ -126,9 +126,9 @@ describe('base/ui/toast', () => {
})
})

fireEvent.mouseEnter(screen.getByRole('region', { name: 'common.toast.notifications' }))
fireEvent.mouseEnter(screen.getByRole('region', { name: 'Notifications' }))

const dismissButton = await screen.findByRole('button', { name: 'common.toast.close' })
const dismissButton = await screen.findByRole('button', { name: 'Close notification' })

act(() => {
dismissButton.click()

@ -7,7 +7,6 @@ import type {
} from '@base-ui/react/toast'
import type { ReactNode } from 'react'
import { Toast as BaseToast } from '@base-ui/react/toast'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'

type ToastData = Record<string, never>
@ -35,6 +34,9 @@ const TOAST_TONE_STYLES = {
},
} satisfies Record<string, ToastToneStyle>

const toastCloseLabel = 'Close notification'
const toastViewportLabel = 'Notifications'

type ToastType = keyof typeof TOAST_TONE_STYLES

type ToastAddOptions = Omit<ToastManagerAddOptions<ToastData>, 'data' | 'positionerProps' | 'type'> & {

@ -145,7 +147,6 @@ function ToastCard({
}: {
toast: ToastObject<ToastData>
}) {
const { t } = useTranslation('common')
const toastType = getToastType(toastItem.type)

return (
@ -200,7 +201,7 @@ function ToastCard({
</div>
<div className="flex shrink-0 items-center justify-center rounded-md p-0.5">
<BaseToast.Close
aria-label={t('toast.close')}
aria-label={toastCloseLabel}
className={cn(
'flex h-5 w-5 items-center justify-center rounded-md hover:bg-state-base-hover focus-visible:bg-state-base-hover focus-visible:ring-1 focus-visible:ring-components-input-border-hover focus-visible:outline-hidden disabled:cursor-not-allowed disabled:opacity-50',
)}
@ -215,12 +216,11 @@ function ToastCard({
}

function ToastViewport() {
const { t } = useTranslation('common')
const { toasts } = BaseToast.useToastManager<ToastData>()

return (
<BaseToast.Viewport
aria-label={t('toast.notifications')}
aria-label={toastViewportLabel}
className={cn(
// During overlay migration, toast must stay above legacy highPriority modals (z-[1100]).
'inset-0 group/toast-viewport pointer-events-none fixed z-1101 overflow-visible',
46
web/app/components/workflow/__tests__/block-icon.spec.tsx
Normal file
@ -0,0 +1,46 @@
import { render } from '@testing-library/react'
import { API_PREFIX } from '@/config'
import BlockIcon, { VarBlockIcon } from '../block-icon'
import { BlockEnum } from '../types'

describe('BlockIcon', () => {
it('renders the default workflow icon container for regular nodes', () => {
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)

const iconContainer = container.firstElementChild
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
})

it('normalizes protected plugin icon urls for tool-like nodes', () => {
const { container } = render(
<BlockIcon
type={BlockEnum.Tool}
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
/>,
)

const iconContainer = container.firstElementChild as HTMLElement
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement

expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
expect(backgroundIcon.style.backgroundImage).toContain(
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
)
})
})

describe('VarBlockIcon', () => {
it('renders the compact icon variant without the default container wrapper', () => {
const { container } = render(
<VarBlockIcon
type={BlockEnum.Answer}
className="custom-var-icon"
/>,
)

expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
})
})
39
web/app/components/workflow/__tests__/context.spec.tsx
Normal file
@ -0,0 +1,39 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { WorkflowContextProvider } from '../context'
import { useStore, useWorkflowStore } from '../store'

const StoreConsumer = () => {
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
const store = useWorkflowStore()

return (
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
{showSingleRunPanel ? 'open' : 'closed'}
</button>
)
}

describe('WorkflowContextProvider', () => {
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
const user = userEvent.setup()
const { rerender } = render(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)

expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()

await user.click(screen.getByRole('button', { name: 'closed' }))
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()

rerender(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)

expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
})
})
67
web/app/components/workflow/__tests__/index.spec.tsx
Normal file
@ -0,0 +1,67 @@
import type { Edge, Node } from '../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import WorkflowWithDefaultContext from '../index'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore } from '../workflow-history-store'

const nodes: Node[] = [
{
id: 'node-start',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
},
},
]

const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-start',
target: 'node-end',
sourceHandle: null,
targetHandle: null,
type: 'custom',
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]

const ContextConsumer = () => {
const { store, shortcutsEnabled } = useWorkflowHistoryStore()
const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
const reactFlowStore = useStoreApi()

return (
<div>
{`history:${store.getState().nodes.length}`}
{` shortcuts:${String(shortcutsEnabled)}`}
{` datasets:${datasetCount}`}
{` reactflow:${String(!!reactFlowStore)}`}
</div>
)
}

describe('WorkflowWithDefaultContext', () => {
it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
render(
<WorkflowWithDefaultContext
nodes={nodes}
edges={edges}
>
<ContextConsumer />
</WorkflowWithDefaultContext>,
)

expect(
screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'),
).toBeInTheDocument()
})
})
@ -0,0 +1,51 @@
import { render, screen } from '@testing-library/react'
import ShortcutsName from '../shortcuts-name'

describe('ShortcutsName', () => {
const originalNavigator = globalThis.navigator

afterEach(() => {
Object.defineProperty(globalThis, 'navigator', {
value: originalNavigator,
writable: true,
configurable: true,
})
})

it('renders mac-friendly key labels and style variants', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Macintosh' },
writable: true,
configurable: true,
})

const { container } = render(
<ShortcutsName
keys={['ctrl', 'shift', 's']}
bgColor="white"
textColor="secondary"
/>,
)

expect(screen.getByText('⌘')).toBeInTheDocument()
expect(screen.getByText('⇧')).toBeInTheDocument()
expect(screen.getByText('s')).toBeInTheDocument()
expect(container.querySelector('.system-kbd')).toHaveClass(
'bg-components-kbd-bg-white',
'text-text-tertiary',
)
})

it('keeps raw key names on non-mac systems', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Windows NT' },
writable: true,
configurable: true,
})

render(<ShortcutsName keys={['ctrl', 'alt']} />)

expect(screen.getByText('ctrl')).toBeInTheDocument()
expect(screen.getByText('alt')).toBeInTheDocument()
})
})
@ -0,0 +1,97 @@
import type { Edge, Node } from '../types'
import type { WorkflowHistoryState } from '../workflow-history-store'
import { render, renderHook, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'

const nodes: Node[] = [
{
id: 'node-1',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
selected: true,
},
selected: true,
},
]

const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-1',
target: 'node-2',
sourceHandle: null,
targetHandle: null,
type: 'custom',
selected: true,
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]

const HistoryConsumer = () => {
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()

return (
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
</button>
)
}

describe('WorkflowHistoryProvider', () => {
it('provides workflow history state and shortcut toggles', async () => {
const user = userEvent.setup()

render(
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
<HistoryConsumer />
</WorkflowHistoryProvider>,
)

expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()

await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
})

it('sanitizes selected flags when history state is replaced through the exposed store api', () => {
const wrapper = ({ children }: { children: React.ReactNode }) => (
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
{children}
</WorkflowHistoryProvider>
)

const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
const nextState: WorkflowHistoryState = {
workflowHistoryEvent: undefined,
workflowHistoryEventMeta: undefined,
nodes,
edges,
}

result.current.store.setState(nextState)

expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
expect(result.current.store.getState().edges[0].selected).toBe(false)
})

it('throws when consumed outside the provider', () => {
expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
)
})
})
@ -0,0 +1,140 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AllTools from '../all-tools'
import { createGlobalPublicStoreState, createToolProvider } from './factories'

vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: vi.fn(),
}))

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
  useMarketplacePlugins: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

vi.mock('@/utils/var', async importOriginal => ({
  ...(await importOriginal<typeof import('@/utils/var')>()),
  getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))

const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

const createMarketplacePluginsMock = () => ({
  plugins: [],
  total: 0,
  resetPlugins: vi.fn(),
  queryPlugins: vi.fn(),
  queryPluginsWithDebounced: vi.fn(),
  cancelQueryPluginsWithDebounced: vi.fn(),
  isLoading: false,
  isFetchingNextPage: false,
  hasNextPage: false,
  fetchNextPage: vi.fn(),
  page: 0,
})

describe('AllTools', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
    mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
  })

  it('filters tools by the active tab', async () => {
    const user = userEvent.setup()

    render(
      <AllTools
        searchText=""
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[createToolProvider({
          id: 'provider-built-in',
          label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
        })]}
        customTools={[createToolProvider({
          id: 'provider-custom',
          type: 'custom',
          label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
        })]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )

    expect(screen.getByText('Built In Provider')).toBeInTheDocument()
    expect(screen.getByText('Custom Provider')).toBeInTheDocument()

    await user.click(screen.getByText('workflow.tabs.customTool'))

    expect(screen.getByText('Custom Provider')).toBeInTheDocument()
    expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
  })

  it('filters the rendered tools by the search text', () => {
    render(
      <AllTools
        searchText="report"
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[
          createToolProvider({
            id: 'provider-report',
            label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
          }),
          createToolProvider({
            id: 'provider-other',
            label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
          }),
        ]}
        customTools={[]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )

    expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
    expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
  })

  it('shows the empty state when no tool matches the current filter', async () => {
    render(
      <AllTools
        searchText="missing"
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[]}
        customTools={[]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )

    await waitFor(() => {
      expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
    })
  })
})
@ -1,51 +1,110 @@
import type { NodeDefault } from '../../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import userEvent from '@testing-library/user-event'
import { AppTypeEnum } from '@/types/app'
import { BlockEnum } from '../../types'
import Blocks from '../blocks'
import { BlockClassificationEnum } from '../types'

const mockGetNodes = vi.fn(() => [])
const mockAppType = vi.hoisted<{ current?: string }>(() => ({
  current: 'workflow',
const runtimeState = vi.hoisted(() => ({
  appType: 'workflow' as string | undefined,
  nodes: [] as Array<{ data: { type?: BlockEnum } }>,
}))

vi.mock('reactflow', () => ({
  useStoreApi: vi.fn(),
  useStoreApi: () => ({
    getState: () => ({
      getNodes: () => runtimeState.nodes,
    }),
  }),
}))

vi.mock('@/app/components/app/store', () => ({
  useStore: (selector: (state: { appDetail: { type?: string } }) => unknown) => selector({
    appDetail: {
      type: mockAppType.current,
      type: runtimeState.appType,
    },
  }),
}))

const mockUseStoreApi = vi.mocked(useStoreApi)
const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({
  metaData: {
    classification,
    sort: 0,
    type,
    title,
    author: 'Dify',
    description: `${title} description`,
  },
  defaultValue: {},
  checkValid: () => ({ isValid: true }),
})

describe('Blocks', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockAppType.current = AppTypeEnum.WORKFLOW
    mockUseStoreApi.mockReturnValue({
      getState: () => ({
        getNodes: mockGetNodes,
      }),
    } as unknown as ReturnType<typeof useStoreApi>)
    runtimeState.appType = AppTypeEnum.WORKFLOW
    runtimeState.nodes = []
  })

  it('should hide human input in evaluation workflows', () => {
    mockAppType.current = AppTypeEnum.EVALUATION
  it('should hide human input blocks when the app is an evaluation workflow', () => {
    runtimeState.appType = AppTypeEnum.EVALUATION

    render(
      <Blocks
        searchText=""
        onSelect={vi.fn()}
        availableBlocksTypes={[BlockEnum.HumanInput, BlockEnum.LLM]}
        blocks={[
          createBlock(BlockEnum.HumanInput, 'Human Input'),
          createBlock(BlockEnum.LLM, 'LLM'),
        ]}
      />,
    )

    expect(screen.queryByText('workflow.blocks.human-input')).not.toBeInTheDocument()
    expect(screen.getByText('workflow.blocks.llm')).toBeInTheDocument()
    expect(screen.queryByText('Human Input')).not.toBeInTheDocument()
    expect(screen.getByText('LLM')).toBeInTheDocument()
  })

  it('should render grouped blocks, filter duplicate knowledge-base nodes, and select a block', async () => {
    const user = userEvent.setup()
    const onSelect = vi.fn()

    runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]

    render(
      <Blocks
        searchText=""
        onSelect={onSelect}
        availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
        blocks={[
          createBlock(BlockEnum.LLM, 'LLM'),
          createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
          createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
        ]}
      />,
    )

    expect(screen.getByText('LLM')).toBeInTheDocument()
    expect(screen.getByText('Exit Loop')).toBeInTheDocument()
    expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
    expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()

    await user.click(screen.getByText('LLM'))

    expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
  })

  it('should show the empty state when no block matches the search text', () => {
    render(
      <Blocks
        searchText="missing"
        onSelect={vi.fn()}
        availableBlocksTypes={[BlockEnum.LLM]}
        blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
      />,
    )

    expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
  })
})
@ -0,0 +1,101 @@
import type { ToolWithProvider } from '../../types'
import type { Plugin } from '@/app/components/plugins/types'
import type { Tool } from '@/app/components/tools/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { CollectionType } from '@/app/components/tools/types'
import { defaultSystemFeatures } from '@/types/feature'

export const createTool = (
  name: string,
  label: string,
  description = `${label} description`,
): Tool => ({
  name,
  author: 'author',
  label: {
    en_US: label,
    zh_Hans: label,
  },
  description: {
    en_US: description,
    zh_Hans: description,
  },
  parameters: [],
  labels: [],
  output_schema: {},
})

export const createToolProvider = (
  overrides: Partial<ToolWithProvider> = {},
): ToolWithProvider => ({
  id: 'provider-1',
  name: 'provider-one',
  author: 'Provider Author',
  description: {
    en_US: 'Provider description',
    zh_Hans: 'Provider description',
  },
  icon: 'icon',
  icon_dark: 'icon-dark',
  label: {
    en_US: 'Provider One',
    zh_Hans: 'Provider One',
  },
  type: CollectionType.builtIn,
  team_credentials: {},
  is_team_authorization: false,
  allow_delete: false,
  labels: [],
  plugin_id: 'plugin-1',
  tools: [createTool('tool-a', 'Tool A')],
  meta: { version: '1.0.0' } as ToolWithProvider['meta'],
  plugin_unique_identifier: 'plugin-1@1.0.0',
  ...overrides,
})

export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
  type: 'plugin',
  org: 'org',
  author: 'author',
  name: 'Plugin One',
  plugin_id: 'plugin-1',
  version: '1.0.0',
  latest_version: '1.0.0',
  latest_package_identifier: 'plugin-1@1.0.0',
  icon: 'icon',
  verified: true,
  label: {
    en_US: 'Plugin One',
    zh_Hans: 'Plugin One',
  },
  brief: {
    en_US: 'Plugin description',
    zh_Hans: 'Plugin description',
  },
  description: {
    en_US: 'Plugin description',
    zh_Hans: 'Plugin description',
  },
  introduction: 'Plugin introduction',
  repository: 'https://example.com/plugin',
  category: PluginCategoryEnum.tool,
  tags: [],
  badges: [],
  install_count: 0,
  endpoint: {
    settings: [],
  },
  verification: {
    authorized_category: 'community',
  },
  from: 'github',
  ...overrides,
})

export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
  systemFeatures: {
    ...defaultSystemFeatures,
    enable_marketplace: enableMarketplace,
  },
  setSystemFeatures: vi.fn(),
})
@ -0,0 +1,101 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import FeaturedTools from '../featured-tools'
import { createPlugin, createToolProvider } from './factories'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

vi.mock('@/utils/var', async importOriginal => ({
  ...(await importOriginal<typeof import('@/utils/var')>()),
  getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

describe('FeaturedTools', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    localStorage.clear()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('shows more featured tools when the list exceeds the initial quota', async () => {
    const user = userEvent.setup()
    const plugins = Array.from({ length: 6 }, (_, index) =>
      createPlugin({
        plugin_id: `plugin-${index + 1}`,
        latest_package_identifier: `plugin-${index + 1}@1.0.0`,
        label: { en_US: `Plugin ${index + 1}`, zh_Hans: `Plugin ${index + 1}` },
      }))
    const providers = plugins.map((plugin, index) =>
      createToolProvider({
        id: `provider-${index + 1}`,
        plugin_id: plugin.plugin_id,
        label: { en_US: `Provider ${index + 1}`, zh_Hans: `Provider ${index + 1}` },
      }),
    )
    const providerMap = new Map(providers.map(provider => [provider.plugin_id!, provider]))

    render(
      <FeaturedTools
        plugins={plugins}
        providerMap={providerMap}
        onSelect={vi.fn()}
      />,
    )

    expect(screen.getByText('Provider 1')).toBeInTheDocument()
    expect(screen.queryByText('Provider 6')).not.toBeInTheDocument()

    await user.click(screen.getByText('workflow.tabs.showMoreFeatured'))

    expect(screen.getByText('Provider 6')).toBeInTheDocument()
  })

  it('honors the persisted collapsed state', () => {
    localStorage.setItem('workflow_tools_featured_collapsed', 'true')

    render(
      <FeaturedTools
        plugins={[createPlugin()]}
        providerMap={new Map([[
          'plugin-1',
          createToolProvider(),
        ]])}
        onSelect={vi.fn()}
      />,
    )

    expect(screen.getByText('workflow.tabs.featuredTools')).toBeInTheDocument()
    expect(screen.queryByText('Provider One')).not.toBeInTheDocument()
  })

  it('shows the marketplace empty state when no featured tools are available', () => {
    render(
      <FeaturedTools
        plugins={[]}
        providerMap={new Map()}
        onSelect={vi.fn()}
      />,
    )

    expect(screen.getByText('workflow.tabs.noFeaturedPlugins')).toBeInTheDocument()
  })
})
@ -0,0 +1,52 @@
import { act, renderHook } from '@testing-library/react'
import { useTabs, useToolTabs } from '../hooks'
import { TabsEnum, ToolTypeEnum } from '../types'

describe('block-selector hooks', () => {
  it('falls back to the first valid tab when the preferred start tab is disabled', () => {
    const { result } = renderHook(() => useTabs({
      noStart: false,
      hasUserInputNode: true,
      defaultActiveTab: TabsEnum.Start,
    }))

    expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBe(true)
    expect(result.current.activeTab).toBe(TabsEnum.Blocks)
  })

  it('keeps the start tab enabled when forcing it on and resets to a valid tab after disabling blocks', () => {
    const props: Parameters<typeof useTabs>[0] = {
      noBlocks: false,
      noStart: false,
      hasUserInputNode: true,
      forceEnableStartTab: true,
    }

    const { result, rerender } = renderHook(nextProps => useTabs(nextProps), {
      initialProps: props,
    })

    expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBeFalsy()

    act(() => {
      result.current.setActiveTab(TabsEnum.Blocks)
    })

    rerender({
      ...props,
      noBlocks: true,
      noSources: true,
      noTools: true,
    })

    expect(result.current.activeTab).toBe(TabsEnum.Start)
  })

  it('returns the MCP tab only when it is not hidden', () => {
    const { result: visible } = renderHook(() => useToolTabs())
    const { result: hidden } = renderHook(() => useToolTabs(true))

    expect(visible.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(true)
    expect(hidden.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(false)
  })
})
@ -0,0 +1,90 @@
import type { NodeDefault, ToolWithProvider } from '../../types'
import { screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { BlockEnum } from '../../types'
import NodeSelectorWrapper from '../index'
import { BlockClassificationEnum } from '../types'

vi.mock('reactflow', async () =>
  (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock())

vi.mock('@/service/use-plugins', () => ({
  useFeaturedToolsRecommendations: () => ({
    plugins: [],
    isLoading: false,
  }),
}))

vi.mock('@/service/use-tools', () => ({
  useAllBuiltInTools: () => ({ data: [] }),
  useAllCustomTools: () => ({ data: [] }),
  useAllWorkflowTools: () => ({ data: [] }),
  useAllMCPTools: () => ({ data: [] }),
  useInvalidateAllBuiltInTools: () => vi.fn(),
}))

vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({
    systemFeatures: { enable_marketplace: false },
  }),
}))

const createBlock = (type: BlockEnum, title: string): NodeDefault => ({
  metaData: {
    type,
    title,
    sort: 0,
    classification: BlockClassificationEnum.Default,
    author: 'Dify',
    description: `${title} description`,
  },
  defaultValue: {},
  checkValid: () => ({ isValid: true }),
})

const dataSource: ToolWithProvider = {
  id: 'datasource-1',
  name: 'datasource',
  author: 'Dify',
  description: { en_US: 'Data source', zh_Hans: '数据源' },
  icon: 'icon',
  label: { en_US: 'Data Source', zh_Hans: 'Data Source' },
  type: 'datasource' as ToolWithProvider['type'],
  team_credentials: {},
  is_team_authorization: false,
  allow_delete: false,
  labels: [],
  tools: [],
  meta: { version: '1.0.0' } as ToolWithProvider['meta'],
}

describe('NodeSelectorWrapper', () => {
  it('filters hidden block types from hooks store and forwards data sources', async () => {
    renderWorkflowComponent(
      <NodeSelectorWrapper
        open
        onSelect={vi.fn()}
        availableBlocksTypes={[BlockEnum.Code]}
      />,
      {
        hooksStoreProps: {
          availableNodesMetaData: {
            nodes: [
              createBlock(BlockEnum.Start, 'Start'),
              createBlock(BlockEnum.Tool, 'Tool'),
              createBlock(BlockEnum.Code, 'Code'),
              createBlock(BlockEnum.DataSource, 'Data Source'),
            ],
          },
        },
        initialStoreState: {
          dataSourceList: [dataSource],
        },
      },
    )

    expect(await screen.findByText('Code')).toBeInTheDocument()
    expect(screen.queryByText('Start')).not.toBeInTheDocument()
    expect(screen.queryByText('Tool')).not.toBeInTheDocument()
  })
})
@ -0,0 +1,95 @@
import type { NodeDefault } from '../../types'
import { screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { BlockEnum } from '../../types'
import NodeSelector from '../main'
import { BlockClassificationEnum } from '../types'

vi.mock('reactflow', () => ({
  useStoreApi: () => ({
    getState: () => ({
      getNodes: () => [],
    }),
  }),
}))

vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({
    systemFeatures: { enable_marketplace: false },
  }),
}))

vi.mock('@/service/use-plugins', () => ({
  useFeaturedToolsRecommendations: () => ({
    plugins: [],
    isLoading: false,
  }),
}))

vi.mock('@/service/use-tools', () => ({
  useAllBuiltInTools: () => ({ data: [] }),
  useAllCustomTools: () => ({ data: [] }),
  useAllWorkflowTools: () => ({ data: [] }),
  useAllMCPTools: () => ({ data: [] }),
  useInvalidateAllBuiltInTools: () => vi.fn(),
}))

const createBlock = (type: BlockEnum, title: string): NodeDefault => ({
  metaData: {
    classification: BlockClassificationEnum.Default,
    sort: 0,
    type,
    title,
    author: 'Dify',
    description: `${title} description`,
  },
  defaultValue: {},
  checkValid: () => ({ isValid: true }),
})

describe('NodeSelector', () => {
  it('opens with the real blocks tab, filters by search, selects a block, and clears search after close', async () => {
    const user = userEvent.setup()
    const onSelect = vi.fn()

    renderWorkflowComponent(
      <NodeSelector
        onSelect={onSelect}
        blocks={[
          createBlock(BlockEnum.LLM, 'LLM'),
          createBlock(BlockEnum.End, 'End'),
        ]}
        availableBlocksTypes={[BlockEnum.LLM, BlockEnum.End]}
        trigger={open => (
          <button type="button">
            {open ? 'selector-open' : 'selector-closed'}
          </button>
        )}
      />,
    )

    await user.click(screen.getByRole('button', { name: 'selector-closed' }))

    const searchInput = screen.getByPlaceholderText('workflow.tabs.searchBlock')
    expect(screen.getByText('LLM')).toBeInTheDocument()
    expect(screen.getByText('End')).toBeInTheDocument()

    await user.type(searchInput, 'LLM')
    expect(screen.getByText('LLM')).toBeInTheDocument()
    expect(screen.queryByText('End')).not.toBeInTheDocument()

    await user.click(screen.getByText('LLM'))

    expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM, undefined)
    await waitFor(() => {
      expect(screen.queryByPlaceholderText('workflow.tabs.searchBlock')).not.toBeInTheDocument()
    })

    await user.click(screen.getByRole('button', { name: 'selector-closed' }))

    const reopenedInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') as HTMLInputElement
    expect(reopenedInput.value).toBe('')
    expect(screen.getByText('End')).toBeInTheDocument()
  })
})
@ -0,0 +1,95 @@
import { render, screen } from '@testing-library/react'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import Tools from '../tools'
import { ViewType } from '../view-type-select'
import { createToolProvider } from './factories'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

describe('Tools', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('shows the empty state when there are no tools and no search text', () => {
    render(
      <Tools
        tools={[]}
        onSelect={vi.fn()}
        viewType={ViewType.flat}
        hasSearchText={false}
      />,
    )

    expect(screen.getByText('No tools available')).toBeInTheDocument()
  })

  it('renders tree groups for built-in and custom providers', () => {
    render(
      <Tools
        tools={[
          createToolProvider({
            id: 'built-in-provider',
            author: 'Built In',
            label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
          }),
          createToolProvider({
            id: 'custom-provider',
            type: CollectionType.custom,
            label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
          }),
        ]}
        onSelect={vi.fn()}
        viewType={ViewType.tree}
        hasSearchText={false}
      />,
    )

    expect(screen.getByText('Built In')).toBeInTheDocument()
    expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument()
    expect(screen.getByText('Built In Provider')).toBeInTheDocument()
    expect(screen.getByText('Custom Provider')).toBeInTheDocument()
  })

  it('shows the alphabetical index in flat view when enough tools are present', () => {
    const { container } = render(
      <Tools
        tools={Array.from({ length: 11 }, (_, index) =>
          createToolProvider({
            id: `provider-${index}`,
            label: {
              en_US: `${String.fromCharCode(65 + index)} Provider`,
              zh_Hans: `${String.fromCharCode(65 + index)} Provider`,
            },
          }))}
        onSelect={vi.fn()}
        viewType={ViewType.flat}
        hasSearchText={false}
      />,
    )

    expect(container.querySelector('.index-bar')).toBeInTheDocument()
    expect(screen.getByText('A Provider')).toBeInTheDocument()
    expect(screen.getByText('K Provider')).toBeInTheDocument()
  })
})
@ -0,0 +1,99 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { trackEvent } from '@/app/components/base/amplitude'
import { CollectionType } from '@/app/components/tools/types'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { BlockEnum } from '../../../types'
import { createTool, createToolProvider } from '../../__tests__/factories'
import { ViewType } from '../../view-type-select'
import Tool from '../tool'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/base/amplitude', () => ({
  trackEvent: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
const mockTrackEvent = vi.mocked(trackEvent)

describe('Tool', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('expands a provider and selects an action item', async () => {
    const user = userEvent.setup()
    const onSelect = vi.fn()

    render(
      <Tool
        payload={createToolProvider({
          tools: [
            createTool('tool-a', 'Tool A'),
            createTool('tool-b', 'Tool B'),
          ],
        })}
        viewType={ViewType.flat}
        hasSearchText={false}
        onSelect={onSelect}
      />,
    )

    await user.click(screen.getByText('Provider One'))
    await user.click(screen.getByText('Tool B'))

    expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({
      provider_id: 'provider-1',
      provider_name: 'provider-one',
      tool_name: 'tool-b',
      title: 'Tool B',
    }))
    expect(mockTrackEvent).toHaveBeenCalledWith('tool_selected', {
      tool_name: 'tool-b',
      plugin_id: 'plugin-1',
    })
  })

  it('selects workflow tools directly without expanding the provider', async () => {
    const user = userEvent.setup()
    const onSelect = vi.fn()

    render(
      <Tool
        payload={createToolProvider({
          type: CollectionType.workflow,
          tools: [createTool('workflow-tool', 'Workflow Tool')],
        })}
        viewType={ViewType.flat}
        hasSearchText={false}
        onSelect={onSelect}
      />,
    )

    await user.click(screen.getByText('Workflow Tool'))

    expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({
      provider_type: CollectionType.workflow,
      tool_name: 'workflow-tool',
      tool_label: 'Workflow Tool',
    }))
  })
})
@ -0,0 +1,66 @@
import { render, screen } from '@testing-library/react'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import List from '../list'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

describe('ToolListFlatView', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('assigns the first tool of each letter to the shared refs and renders the index bar', () => {
    const toolRefs = {
      current: {} as Record<string, HTMLDivElement | null>,
    }

    render(
      <List
        letters={['A', 'B']}
        payload={[
          createToolProvider({
            id: 'provider-a',
            label: { en_US: 'A Provider', zh_Hans: 'A Provider' },
            letter: 'A',
          } as ReturnType<typeof createToolProvider>),
          createToolProvider({
            id: 'provider-b',
            label: { en_US: 'B Provider', zh_Hans: 'B Provider' },
            letter: 'B',
          } as ReturnType<typeof createToolProvider>),
        ]}
        isShowLetterIndex
        indexBar={<div data-testid="index-bar" />}
        hasSearchText={false}
        onSelect={vi.fn()}
        toolRefs={toolRefs}
      />,
    )

    expect(screen.getByText('A Provider')).toBeInTheDocument()
    expect(screen.getByText('B Provider')).toBeInTheDocument()
    expect(screen.getByTestId('index-bar')).toBeInTheDocument()
    expect(toolRefs.current.A).toBeTruthy()
    expect(toolRefs.current.B).toBeTruthy()
  })
})
@ -0,0 +1,47 @@
import { render, screen } from '@testing-library/react'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import Item from '../item'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

describe('ToolListTreeView Item', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('renders the group heading and its provider list', () => {
    render(
      <Item
        groupName="My Group"
        toolList={[createToolProvider({
          label: { en_US: 'Provider Alpha', zh_Hans: 'Provider Alpha' },
        })]}
        hasSearchText={false}
        onSelect={vi.fn()}
      />,
    )

    expect(screen.getByText('My Group')).toBeInTheDocument()
    expect(screen.getByText('Provider Alpha')).toBeInTheDocument()
  })
})
@ -0,0 +1,56 @@
import { render, screen } from '@testing-library/react'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import { createToolProvider } from '../../../__tests__/factories'
import { CUSTOM_GROUP_NAME } from '../../../index-bar'
import List from '../list'

vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))

vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))

vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))

const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)

describe('ToolListTreeView', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
  })

  it('translates built-in special group names and renders the nested providers', () => {
    render(
      <List
        payload={{
          BuiltIn: [createToolProvider({
            label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
          })],
          [CUSTOM_GROUP_NAME]: [createToolProvider({
            id: 'custom-provider',
            type: 'custom',
            label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
          })],
        }}
        hasSearchText={false}
        onSelect={vi.fn()}
      />,
    )

    expect(screen.getByText('BuiltIn')).toBeInTheDocument()
    expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument()
    expect(screen.getByText('Built In Provider')).toBeInTheDocument()
    expect(screen.getByText('Custom Provider')).toBeInTheDocument()
  })
})
@ -0,0 +1,91 @@
import type { DataSet } from '@/models/datasets'
import { renderHook } from '@testing-library/react'
import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets'
import { DatasetsDetailContext } from '../provider'
import { createDatasetsDetailStore, useDatasetsDetailStore } from '../store'

const createDataset = (id: string, name = `dataset-${id}`): DataSet => ({
  id,
  name,
  indexing_status: 'completed',
  icon_info: {
    icon: 'book',
    icon_type: 'emoji' as DataSet['icon_info']['icon_type'],
  },
  description: `${name} description`,
  permission: DatasetPermission.onlyMe,
  data_source_type: DataSourceType.FILE,
  indexing_technique: 'high_quality' as DataSet['indexing_technique'],
  created_by: 'user-1',
  updated_by: 'user-1',
  updated_at: 1,
  app_count: 0,
  doc_form: ChunkingMode.text,
  document_count: 0,
  total_document_count: 0,
  word_count: 0,
  provider: 'provider',
  embedding_model: 'model',
  embedding_model_provider: 'provider',
  embedding_available: true,
  retrieval_model_dict: {} as DataSet['retrieval_model_dict'],
  retrieval_model: {} as DataSet['retrieval_model'],
  tags: [],
  external_knowledge_info: {
    external_knowledge_id: '',
    external_knowledge_api_id: '',
    external_knowledge_api_name: '',
    external_knowledge_api_endpoint: '',
  },
  external_retrieval_model: {
    top_k: 1,
    score_threshold: 0,
    score_threshold_enabled: false,
  },
  built_in_field_enabled: false,
  runtime_mode: 'general',
  enable_api: false,
  is_multimodal: false,
})

describe('datasets-detail-store store', () => {
  it('merges dataset details by id', () => {
    const store = createDatasetsDetailStore()

    store.getState().updateDatasetsDetail([
      createDataset('dataset-1', 'Dataset One'),
      createDataset('dataset-2', 'Dataset Two'),
    ])
    store.getState().updateDatasetsDetail([
      createDataset('dataset-2', 'Dataset Two Updated'),
    ])

    expect(store.getState().datasetsDetail).toMatchObject({
      'dataset-1': { name: 'Dataset One' },
      'dataset-2': { name: 'Dataset Two Updated' },
    })
  })

  it('reads state from the datasets detail context', () => {
    const store = createDatasetsDetailStore()
    store.getState().updateDatasetsDetail([createDataset('dataset-3')])
    const wrapper = ({ children }: { children: React.ReactNode }) => (
      <DatasetsDetailContext.Provider value={store}>
        {children}
      </DatasetsDetailContext.Provider>
    )

    const { result } = renderHook(
      () => useDatasetsDetailStore(state => state.datasetsDetail['dataset-3']?.name),
      { wrapper },
    )

    expect(result.current).toBe('dataset-dataset-3')
  })

  it('throws when the datasets detail provider is missing', () => {
    expect(() => renderHook(() => useDatasetsDetailStore(state => state.datasetsDetail))).toThrow(
      'Missing DatasetsDetailContext.Provider in the tree',
    )
  })
})
@ -0,0 +1,41 @@
import { renderHook } from '@testing-library/react'
import { HooksStoreContext } from '../provider'
import { createHooksStore, useHooksStore } from '../store'

describe('hooks-store store', () => {
  it('creates default callbacks and refreshes selected handlers', () => {
    const store = createHooksStore({})
    const handleBackupDraft = vi.fn()

    expect(store.getState().availableNodesMetaData).toEqual({ nodes: [] })
    expect(store.getState().hasNodeInspectVars('node-1')).toBe(false)
    expect(store.getState().getWorkflowRunAndTraceUrl('run-1')).toEqual({
      runUrl: '',
      traceUrl: '',
    })

    store.getState().refreshAll({ handleBackupDraft })

    expect(store.getState().handleBackupDraft).toBe(handleBackupDraft)
  })

  it('reads state from the hooks store context', () => {
    const handleRun = vi.fn()
    const store = createHooksStore({ handleRun })
    const wrapper = ({ children }: { children: React.ReactNode }) => (
      <HooksStoreContext.Provider value={store}>
        {children}
      </HooksStoreContext.Provider>
    )

    const { result } = renderHook(() => useHooksStore(state => state.handleRun), { wrapper })

    expect(result.current).toBe(handleRun)
  })

  it('throws when the hooks store provider is missing', () => {
    expect(() => renderHook(() => useHooksStore(state => state.handleRun))).toThrow(
      'Missing HooksStoreContext.Provider in the tree',
    )
  })
})
19
web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts
Normal file
@ -0,0 +1,19 @@
import { renderWorkflowHook } from '../../__tests__/workflow-test-env'
import { useDSL } from '../use-DSL'

describe('useDSL', () => {
  it('returns the DSL handlers from hooks store', () => {
    const exportCheck = vi.fn()
    const handleExportDSL = vi.fn()

    const { result } = renderWorkflowHook(() => useDSL(), {
      hooksStoreProps: {
        exportCheck,
        handleExportDSL,
      },
    })

    expect(result.current.exportCheck).toBe(exportCheck)
    expect(result.current.handleExportDSL).toBe(handleExportDSL)
  })
})
@ -0,0 +1,90 @@
import { act, waitFor } from '@testing-library/react'
import { useEdges } from 'reactflow'
import { createEdge, createNode } from '../../__tests__/fixtures'
import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env'
import { NodeRunningStatus } from '../../types'
import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync'

type EdgeRuntimeState = {
  _sourceRunningStatus?: NodeRunningStatus
  _targetRunningStatus?: NodeRunningStatus
  _waitingRun?: boolean
}

const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState =>
  (edge?.data ?? {}) as EdgeRuntimeState

const createFlowNodes = () => [
  createNode({ id: 'a' }),
  createNode({ id: 'b' }),
  createNode({ id: 'c' }),
]

const createFlowEdges = () => [
  createEdge({
    id: 'e1',
    source: 'a',
    target: 'b',
    data: {
      _sourceRunningStatus: NodeRunningStatus.Running,
      _targetRunningStatus: NodeRunningStatus.Running,
      _waitingRun: true,
    },
  }),
  createEdge({
    id: 'e2',
    source: 'b',
    target: 'c',
    data: {
      _sourceRunningStatus: NodeRunningStatus.Succeeded,
      _targetRunningStatus: undefined,
      _waitingRun: false,
    },
  }),
]

const renderEdgesInteractionsHook = () =>
  renderWorkflowFlowHook(() => ({
    ...useEdgesInteractionsWithoutSync(),
    edges: useEdges(),
  }), {
    nodes: createFlowNodes(),
    edges: createFlowEdges(),
  })

describe('useEdgesInteractionsWithoutSync', () => {
  it('clears running status and waitingRun on all edges', () => {
    const { result } = renderEdgesInteractionsHook()

    act(() => {
      result.current.handleEdgeCancelRunningStatus()
    })

    return waitFor(() => {
      result.current.edges.forEach((edge) => {
        const edgeState = getEdgeRuntimeState(edge)
        expect(edgeState._sourceRunningStatus).toBeUndefined()
        expect(edgeState._targetRunningStatus).toBeUndefined()
        expect(edgeState._waitingRun).toBe(false)
      })
    })
  })

  it('does not mutate the original edges array', () => {
    const edges = createFlowEdges()
    const originalData = { ...getEdgeRuntimeState(edges[0]) }
    const { result } = renderWorkflowFlowHook(() => ({
      ...useEdgesInteractionsWithoutSync(),
      edges: useEdges(),
    }), {
      nodes: createFlowNodes(),
      edges,
    })

    act(() => {
      result.current.handleEdgeCancelRunningStatus()
    })

    expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus)
  })
})
@ -0,0 +1,114 @@
import { createEdge, createNode } from '../../__tests__/fixtures'
import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../../utils'
import {
  applyConnectedHandleNodeData,
  buildContextMenuEdges,
  clearEdgeMenuIfNeeded,
  clearNodeSelectionState,
  updateEdgeHoverState,
  updateEdgeSelectionState,
} from '../use-edges-interactions.helpers'

vi.mock('../../utils', () => ({
  getNodesConnectedSourceOrTargetHandleIdsMap: vi.fn(),
}))

const mockGetNodesConnectedSourceOrTargetHandleIdsMap = vi.mocked(getNodesConnectedSourceOrTargetHandleIdsMap)

describe('use-edges-interactions.helpers', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  it('applyConnectedHandleNodeData should merge connected handle metadata into matching nodes', () => {
    mockGetNodesConnectedSourceOrTargetHandleIdsMap.mockReturnValue({
      'node-1': {
        _connectedSourceHandleIds: ['branch-a'],
      },
    })

    const nodes = [
      createNode({ id: 'node-1', data: { title: 'Source' } }),
      createNode({ id: 'node-2', data: { title: 'Target' } }),
    ]
    const edgeChanges = [{
      type: 'add',
      edge: createEdge({ id: 'edge-1', source: 'node-1', target: 'node-2' }),
    }]

    const result = applyConnectedHandleNodeData(nodes, edgeChanges)

    expect(result[0].data._connectedSourceHandleIds).toEqual(['branch-a'])
    expect(result[1].data._connectedSourceHandleIds).toEqual([])
    expect(mockGetNodesConnectedSourceOrTargetHandleIdsMap).toHaveBeenCalledWith(edgeChanges, nodes)
  })

  it('clearEdgeMenuIfNeeded should return true only when the open menu belongs to a removed edge', () => {
    expect(clearEdgeMenuIfNeeded({
      edgeMenu: { edgeId: 'edge-1' },
      edgeIds: ['edge-1', 'edge-2'],
    })).toBe(true)

    expect(clearEdgeMenuIfNeeded({
      edgeMenu: { edgeId: 'edge-3' },
      edgeIds: ['edge-1', 'edge-2'],
    })).toBe(false)

    expect(clearEdgeMenuIfNeeded({
      edgeIds: ['edge-1'],
    })).toBe(false)
  })

  it('updateEdgeHoverState should toggle only the hovered edge flag', () => {
    const edges = [
      createEdge({ id: 'edge-1', data: { _hovering: false } }),
      createEdge({ id: 'edge-2', data: { _hovering: false } }),
    ]

    const result = updateEdgeHoverState(edges, 'edge-2', true)

    expect(result.find(edge => edge.id === 'edge-1')?.data._hovering).toBe(false)
    expect(result.find(edge => edge.id === 'edge-2')?.data._hovering).toBe(true)
  })

  it('updateEdgeSelectionState should update selected flags for select changes only', () => {
    const edges = [
      createEdge({ id: 'edge-1', selected: false }),
      createEdge({ id: 'edge-2', selected: true }),
    ]

    const result = updateEdgeSelectionState(edges, [
      { type: 'select', id: 'edge-1', selected: true },
      { type: 'remove', id: 'edge-2' },
    ])

    expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(true)
    expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true)
  })

  it('buildContextMenuEdges should select the target edge and clear bundled markers', () => {
    const edges = [
      createEdge({ id: 'edge-1', selected: true, data: { _isBundled: true } }),
      createEdge({ id: 'edge-2', selected: false, data: { _isBundled: true } }),
    ]

    const result = buildContextMenuEdges(edges, 'edge-2')

    expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(false)
    expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true)
    expect(result.every(edge => edge.data._isBundled === false)).toBe(true)
  })

  it('clearNodeSelectionState should clear selected state and bundled markers on every node', () => {
    const nodes = [
      createNode({ id: 'node-1', selected: true, data: { selected: true, _isBundled: true } }),
      createNode({ id: 'node-2', selected: false, data: { selected: true, _isBundled: true } }),
    ]

    const result = clearNodeSelectionState(nodes)

    expect(result.every(node => node.selected === false)).toBe(true)
    expect(result.every(node => node.data.selected === false)).toBe(true)
    expect(result.every(node => node.data._isBundled === false)).toBe(true)
  })
})
Some files were not shown because too many files have changed in this diff.