Merge branch 'main' into jzh

JzoNg 2026-04-09 15:36:00 +08:00
commit 5c93d74dec
37 changed files with 444 additions and 502 deletions

View File

@ -109,6 +109,7 @@ S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
S3_ADDRESS_STYLE=auto
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
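
The new S3_ADDRESS_STYLE setting selects between path-style and virtual-hosted-style bucket URLs; how Dify consumes the variable is not shown in this hunk. As a hedged sketch of the idea only, here is how the same switch looks in boto3, whose Config accepts "auto", "path", or "virtual" (the env-var names match the block above; the boto3 wiring itself is an assumption, not Dify's actual code):

import os

import boto3
from botocore.config import Config

# Hypothetical wiring: addressing_style takes "auto", "path", or "virtual".
client = boto3.client(
    "s3",
    region_name=os.environ.get("S3_REGION"),
    aws_access_key_id=os.environ.get("S3_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("S3_SECRET_KEY"),
    config=Config(s3={"addressing_style": os.environ.get("S3_ADDRESS_STYLE", "auto")}),
)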

View File

@ -34,9 +34,10 @@ from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,

View File

@ -17,8 +17,9 @@ from fields.app_fields import (
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.app_dsl_service import AppDslService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportStatus
from services.feature_service import FeatureService
from .. import console_ns

View File

@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

View File

@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
from extensions.ext_database import db
from models import Account, App
from models.account import AccountStatus
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
from services.app_dsl_service import AppDslService
from services.entities.dsl_entities import ImportMode, ImportStatus
class InnerAppDSLImportPayload(BaseModel):

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields: list[str] = []
class _Table(CollectionORM):
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
),
)
pks.append(pk)
session.commit()
return pks
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]):
with Session(self._client) as session:
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self):
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
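
The recurring pattern in this file (and in several files below) replaces a manually committed Session block with sessionmaker(bind=...).begin(), SQLAlchemy's transactional context manager: it commits when the block exits cleanly and rolls back if it raises, which is why the explicit session.commit() calls disappear from each hunk. A minimal sketch against a throwaway in-memory engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")  # placeholder engine for illustration

# Before: commit must be called by hand; an exception leaves work uncommitted
# but performs no explicit rollback.
with Session(engine) as session:
    session.execute(text("SELECT 1"))
    session.commit()

# After: begin() opens a transaction that commits on success and rolls
# back automatically if the block raises.
with sessionmaker(bind=engine).begin() as session:
    session.execute(text("SELECT 1"))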

View File

@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel):
parameter.pop("input_schema", None)
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
match self.type:
case ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
elif self.type == ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
optional_fields.update(
self.optional_field(
"authentication", self.authentication.model_dump() if self.authentication else None
)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
case ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
case _:
pass
return {
"id": self.id,
"author": self.author,

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import httpx
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@ -166,13 +167,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
if not tool_file:
return None
@ -190,13 +185,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
MessageFile.id == id,
)
.first()
)
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
# Check if message_file is not None
if message_file is not None:
@ -210,13 +199,7 @@ class ToolFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None
@ -234,13 +217,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None, None
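
These hunks swap the legacy Query API (session.query(Model).where(...).first()) for the 2.0-style session.scalar(select(Model).where(...).limit(1)), which returns the first result or None. A self-contained sketch of the equivalence; ToolFile below is a simplified stand-in model, not the real Dify model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ToolFile(Base):  # simplified stand-in for illustration only
    __tablename__ = "tool_files"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    file_key: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(ToolFile(id="f1", file_key="k1"))
    session.commit()
    # Equivalent of the removed session.query(ToolFile).where(...).first():
    tool_file = session.scalar(select(ToolFile).where(ToolFile.id == "f1").limit(1))
    assert tool_file is not None and tool_file.file_key == "k1"
    # Missing rows come back as None rather than raising:
    assert session.scalar(select(ToolFile).where(ToolFile.id == "nope")) is None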

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow | None = (
session.query(Workflow)
workflow: Workflow | None = session.scalar(
select(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
.limit(1)
)
if not workflow:
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
db_provider: WorkflowToolProvider | None = session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == self.provider_id,
)
.first()
.limit(1)
)
if not db_provider:

View File

@ -813,56 +813,32 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
@staticmethod
def _dump_optional(value: Any) -> str | None:
return json.dumps(value) if value else None
def from_model_config_dict(self, model_config: AppModelConfigDict):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config.get("suggested_questions_after_answer"))
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.sensitive_word_avoidance = (
json.dumps(model_config.get("sensitive_word_avoidance"))
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
self.suggested_questions = self._dump_optional(model_config.get("suggested_questions"))
self.suggested_questions_after_answer = self._dump_optional(
model_config.get("suggested_questions_after_answer")
)
self.speech_to_text = self._dump_optional(model_config.get("speech_to_text"))
self.text_to_speech = self._dump_optional(model_config.get("text_to_speech"))
self.more_like_this = self._dump_optional(model_config.get("more_like_this"))
self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance"))
self.external_data_tools = self._dump_optional(model_config.get("external_data_tools"))
self.model = self._dump_optional(model_config.get("model"))
self.user_input_form = self._dump_optional(model_config.get("user_input_form"))
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.agent_mode = self._dump_optional(model_config.get("agent_mode"))
self.retriever_resource = self._dump_optional(model_config.get("retriever_resource"))
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
self.chat_prompt_config = (
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
)
self.completion_prompt_config = (
json.dumps(model_config.get("completion_prompt_config"))
if model_config.get("completion_prompt_config")
else None
)
self.dataset_configs = (
json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None
)
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config"))
self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config"))
self.dataset_configs = self._dump_optional(model_config.get("dataset_configs"))
self.file_upload = self._dump_optional(model_config.get("file_upload"))
return self
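
The new _dump_optional helper preserves the exact truthiness semantics of the repeated inline expressions it replaces: any falsy value (None, {}, [], "") becomes a NULL column rather than a JSON string. A quick check of that behavior:

import json
from typing import Any

def _dump_optional(value: Any) -> str | None:
    return json.dumps(value) if value else None

assert _dump_optional(None) is None
assert _dump_optional({}) is None  # empty containers are dropped too, as before
assert _dump_optional({"enabled": True}) == '{"enabled": true}'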

View File

@ -3,7 +3,6 @@ import hashlib
import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat
from graphon.nodes.tool.entities import ToolNodeData
from packaging import version
from packaging.version import parse as parse_version
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.workflow import Workflow
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.6.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class Import(BaseModel):
id: str
status: ImportStatus
@ -75,10 +63,6 @@ class Import(BaseModel):
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:

View File

@ -5,7 +5,7 @@ from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@ -53,13 +53,12 @@ class DatasourceProviderService:
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def decrypt_datasource_provider_credentials(
self,
@ -109,7 +108,7 @@ class DatasourceProviderService:
"""
get credential by id
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
@ -156,7 +155,6 @@ class DatasourceProviderService:
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
@ -174,7 +172,7 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
@ -224,7 +222,6 @@ class DatasourceProviderService:
provider=provider,
)
real_credentials_list.append(real_credentials)
session.commit()
return real_credentials_list
@ -234,7 +231,7 @@ class DatasourceProviderService:
"""
update datasource provider name
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
@ -266,7 +263,6 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists")
target_provider.name = name
session.commit()
return
def set_default_datasource_provider(
@ -275,7 +271,7 @@ class DatasourceProviderService:
"""
set default datasource provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
@ -300,7 +296,6 @@ class DatasourceProviderService:
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
def setup_oauth_custom_client_params(
@ -315,7 +310,7 @@ class DatasourceProviderService:
"""
if client_params is None and enabled is None:
return
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
@ -349,7 +344,6 @@ class DatasourceProviderService:
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
@ -488,7 +482,7 @@ class DatasourceProviderService:
"""
update datasource oauth provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
@ -535,7 +529,6 @@ class DatasourceProviderService:
target_provider.expires_at = expire_at
target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()
def add_datasource_oauth_provider(
self,
@ -550,7 +543,7 @@ class DatasourceProviderService:
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=60):
db_provider_name = name
@ -604,7 +597,6 @@ class DatasourceProviderService:
expires_at=expire_at,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
@ -623,7 +615,7 @@ class DatasourceProviderService:
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
@ -670,7 +662,6 @@ class DatasourceProviderService:
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
"""
@ -926,7 +917,7 @@ class DatasourceProviderService:
update datasource credentials.
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
@ -980,7 +971,6 @@ class DatasourceProviderService:
encrypted_credentials[key] = value
datasource_provider.encrypted_credentials = encrypted_credentials
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""

View File

@ -0,0 +1,21 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.plugin.entities.plugin import PluginDependency
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
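
Extracting ImportMode, ImportStatus, and CheckDependenciesResult into services/entities/dsl_entities.py gives the controllers and the RAG-pipeline DSL service a shared, dependency-light home for these types instead of importing them from the heavier app_dsl_service module. Since the enums subclass StrEnum, members compare equal to their wire strings; a brief usage sketch of the new import path:

from services.entities.dsl_entities import ImportMode, ImportStatus

assert ImportMode.YAML_CONTENT == "yaml-content"  # StrEnum compares to str
assert ImportStatus.FAILED.value == "failed"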

View File

@ -1,3 +1,4 @@
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
@ -8,10 +9,10 @@ class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with sessionmaker(bind=db.engine).begin() as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
return session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
@staticmethod
@ -24,10 +25,10 @@ class PluginAutoUpgradeService:
include_plugins: list[str],
) -> bool:
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
if not exist_strategy:
strategy = TenantPluginAutoUpgradeStrategy(
@ -51,10 +52,10 @@ class PluginAutoUpgradeService:
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
if not exist_strategy:
# create for this tenant

View File

@ -1,3 +1,4 @@
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
@ -8,7 +9,9 @@ class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with sessionmaker(bind=db.engine).begin() as session:
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
return session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
@staticmethod
def change_permission(
@ -17,8 +20,8 @@ class PluginPermissionService:
debug_permission: TenantPluginPermission.DebugPermission,
):
with sessionmaker(bind=db.engine).begin() as session:
permission = (
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
permission = session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
if not permission:
permission = TenantPluginPermission(

View File

@ -555,7 +555,7 @@ class RagPipelineService:
workflow_node_execution.id
)
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=pipeline.id,
@ -569,7 +569,6 @@ class RagPipelineService:
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
)
session.commit()
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
@ -1325,7 +1324,7 @@ class RagPipelineService:
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=pipeline.id,
@ -1339,7 +1338,6 @@ class RagPipelineService:
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,

View File

@ -20,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.nodes.tool.entities import ToolNodeData
from packaging import version
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -37,7 +37,7 @@ from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.enums import CollectionBindingType, DatasetRuntimeMode
from models.workflow import Workflow, WorkflowType
from services.app_dsl_service import ImportMode, ImportStatus
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
KnowledgeConfiguration,
@ -64,10 +64,6 @@ class RagPipelineImportInfo(BaseModel):
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@ -46,13 +46,12 @@ class BuiltinToolManageService:
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
@ -150,7 +149,7 @@ class BuiltinToolManageService:
"""
update builtin tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
@ -203,9 +202,7 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@ -222,7 +219,7 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
@ -281,9 +278,7 @@ class BuiltinToolManageService:
)
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@ -379,7 +374,7 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
db_provider = (
session.query(BuiltinToolProvider)
.where(
@ -393,7 +388,6 @@ class BuiltinToolManageService:
raise ValueError(f"you have not added provider {provider}")
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -409,7 +403,7 @@ class BuiltinToolManageService:
"""
set default provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None:
@ -422,7 +416,6 @@ class BuiltinToolManageService:
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@ -654,7 +647,7 @@ class BuiltinToolManageService:
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
@ -690,7 +683,6 @@ class BuiltinToolManageService:
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod

View File

@ -48,21 +48,25 @@ class ToolTransformService:
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
)
if provider_type == ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
return ""
match provider_type:
case ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
case ToolProviderType.API | ToolProviderType.WORKFLOW:
try:
if isinstance(icon, str):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
case ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
return ""
case _:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):

View File

@ -1075,9 +1075,8 @@ class DraftVariableSaver:
)
engine = bind = self._session.get_bind()
assert isinstance(engine, Engine)
with Session(bind=engine, expire_on_commit=False) as session:
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
session.add(variable_file)
session.commit()
return truncation_result.result, variable_file

View File

@ -837,7 +837,7 @@ class WorkflowService:
with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
@ -848,7 +848,6 @@ class WorkflowService:
user=account,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
@ -977,7 +976,7 @@ class WorkflowService:
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
@ -988,7 +987,6 @@ class WorkflowService:
enclosing_node_id=enclosing_node_id,
)
draft_var_saver.save(outputs=outputs, process_data={})
session.commit()
return outputs

View File

@ -112,7 +112,9 @@ def clean_dataset_task(
segment_ids = [segment.id for segment in segments]
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
@ -150,20 +152,22 @@ def clean_dataset_task(
)
session.execute(binding_delete_stmt)
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
session.execute(
delete(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
)
)
# delete files
if documents:
file_ids = []
@ -174,7 +178,7 @@ def clean_dataset_task(
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
storage.delete(file.key)
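
Here the bulk Query.delete() calls become explicit delete() statements executed on the session, the 2.0-style core construct that issues a single DELETE without loading rows. A minimal sketch; DatasetQuery below is a simplified placeholder model:

from sqlalchemy import String, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class DatasetQuery(Base):  # simplified stand-in for illustration only
    __tablename__ = "dataset_queries"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [DatasetQuery(id="q1", dataset_id="d1"), DatasetQuery(id="q2", dataset_id="d2")]
    )
    session.commit()
    # Bulk DELETE issued directly, replacing session.query(...).delete():
    session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == "d1"))
    session.commit()
    assert [q.id for q in session.scalars(select(DatasetQuery)).all()] == ["q2"]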

View File

@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else:
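
The same migration applies to bulk updates: Query.update({...}, synchronize_session=False) becomes an update() statement with .where() and .values(). A sketch with a placeholder model (not the real DatasetDocument):

from sqlalchemy import String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class DatasetDocument(Base):  # simplified stand-in for illustration only
    __tablename__ = "dataset_documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    indexing_status: Mapped[str] = mapped_column(String, default="waiting")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DatasetDocument(id="doc-1"))
    session.commit()
    # Bulk UPDATE statement replacing the legacy Query.update() call:
    session.execute(
        update(DatasetDocument)
        .where(DatasetDocument.id.in_(["doc-1"]))
        .values(indexing_status="indexing")
    )
    session.commit()
    assert session.scalar(select(DatasetDocument.indexing_status)) == "indexing"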

View File

@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
tenant_id = None
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.exception("Failed to clean vector index for document %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)

View File

@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
return _session
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _sessionmaker_factory(calls, execute_results=None):
def _sessionmaker(*args, **kwargs):
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
return _sessionmaker
def _patch_both(monkeypatch, module, calls, execute_results=None):
"""Patch both Session and sessionmaker on the module with the same call tracker."""
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
@pytest.fixture
def pgvecto_module(monkeypatch):
for name, module in _build_fake_pgvecto_modules().items():
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
vector.create_collection = MagicMock()
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
lock = MagicMock()
lock.__enter__.return_value = None
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
class _InsertBuilder:
def __init__(self, table):
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
"Session",
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
monkeypatch.setattr(
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
],
),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
vector.delete_by_ids(["doc-1"])
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
runtime_calls.clear()
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
vector.delete()
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
init_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
runtime_calls = []
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
]
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
assert len(docs) == 1

View File

@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = None
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [None, None]
# Act
with _patch_session_factory(session):
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
manager = ToolFileManager()
message_file = SimpleNamespace(url=None)
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, None]
# Act
with _patch_session_factory(session):
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = tool_file
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, tool_file]
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
size=12,
)
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:

View File

@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
controller = _controller()
session = Mock()
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
app = SimpleNamespace(id="app-1")
db_provider = SimpleNamespace(
id="provider-1",
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
parameter_configurations=[],
)
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.side_effect = [app, user]
fake_cm = MagicMock()
fake_cm.__enter__.return_value = session
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
session_cls.return_value.__enter__.return_value = session
assert controller.get_tools("tenant-1") == []
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.return_value = None
session_cls.return_value.__enter__.return_value = session
with pytest.raises(ValueError, match="app not found"):

View File

@ -20,7 +20,7 @@ class TestGetStrategy:
def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session()
strategy = MagicMock()
session.query.return_value.where.return_value.first.return_value = strategy
session.scalar.return_value = strategy
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -31,7 +31,7 @@ class TestGetStrategy:
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -44,9 +44,9 @@ class TestGetStrategy:
class TestChangeStrategy:
def test_creates_new_strategy(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -65,7 +65,7 @@ class TestChangeStrategy:
def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -90,11 +90,12 @@ class TestChangeStrategy:
class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with (
p1,
p2,
patch(f"{MODULE}.select"),
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
):
@ -113,9 +114,9 @@ class TestExcludePlugin:
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -131,9 +132,9 @@ class TestExcludePlugin:
existing = MagicMock()
existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -148,9 +149,9 @@ class TestExcludePlugin:
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "all"
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -167,9 +168,9 @@ class TestExcludePlugin:
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"

View File

@ -20,7 +20,7 @@ class TestGetPermission:
def test_returns_permission_when_found(self):
p1, p2, session = _patched_session()
permission = MagicMock()
session.query.return_value.where.return_value.first.return_value = permission
session.scalar.return_value = permission
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
@ -31,7 +31,7 @@ class TestGetPermission:
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
@ -44,9 +44,9 @@ class TestGetPermission:
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
@ -59,7 +59,7 @@ class TestChangePermission:
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService

View File

@ -40,7 +40,10 @@ class TestDatasourceProviderService:
q returns itself for .filter_by(), .order_by(), and .where(), so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
"""
with patch("services.datasource_provider_service.Session") as mock_cls:
with (
patch("services.datasource_provider_service.Session") as mock_cls,
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
):
sess = MagicMock(spec=Session)
q = MagicMock()
@ -63,6 +66,8 @@ class TestDatasourceProviderService:
mock_cls.return_value.__enter__.return_value = sess
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
yield sess
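The fixture docstring above describes a self-chaining query mock; the idea, in a compact sketch using standard unittest.mock behavior (names hypothetical):

# One mock returns itself from every chaining method, so any ordering of
# .filter_by()/.order_by()/.where() resolves to the same terminal mock.
from unittest.mock import MagicMock

q = MagicMock()
q.filter_by.return_value = q
q.order_by.return_value = q
q.where.return_value = q

sess = MagicMock()
sess.query.return_value = q

sess.query(object).filter_by(id="x").order_by("name").first()
sess.query(object).where(True).first()
assert q.first.call_count == 2  # both chains hit the same mock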
@ -266,7 +271,6 @@ class TestDatasourceProviderService:
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
):
service.get_datasource_credentials("t1", "prov", "org/plug")
mock_db_session.commit.assert_called_once()
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
"""API key credentials with expires_at=-1 skip refresh and return directly."""
@ -333,7 +337,6 @@ class TestDatasourceProviderService:
p.name = "same"
mock_db_session.query().first.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
mock_db_session.commit.assert_not_called()
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -352,7 +355,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# set_default_datasource_provider (lines 277-303)
@ -370,7 +372,6 @@ class TestDatasourceProviderService:
mock_db_session.query().first.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# get_oauth_encrypter (lines 404-420)
@ -460,7 +461,6 @@ class TestDatasourceProviderService:
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
@ -512,7 +512,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -523,7 +522,6 @@ class TestDatasourceProviderService:
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
)
mock_db_session.commit.assert_called_once()
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -571,7 +569,6 @@ class TestDatasourceProviderService:
):
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
@ -747,7 +744,6 @@ class TestDatasourceProviderService:
# encrypter must have been called with the new secret value
self._enc.encrypt_token.assert_called()
# commit must be called exactly once
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# remove_datasource_credentials (lines 980-997)
@ -758,7 +754,6 @@ class TestDatasourceProviderService:
mock_db_session.scalar.return_value = p
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_called_once_with(p)
mock_db_session.commit.assert_called_once()
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""

View File

@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
return session
def _mock_sessionmaker(mock_sm_cls):
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
session = MagicMock()
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
return session
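The helper mirrors SQLAlchemy's sessionmaker(...).begin() context manager, which commits on clean exit, so service code using it needs no explicit session.commit(); __exit__ returning False preserves exception propagation so pytest.raises still works. A rough production-side sketch using the standard API:

# Standard SQLAlchemy shape being mocked: begin() yields a session and
# commits when the block exits cleanly, so no explicit session.commit().
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")  # in-memory stand-in for the real bind

with sessionmaker(bind=engine).begin() as session:
    session.execute(text("SELECT 1"))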
class TestDeleteCustomOauthClientParams:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.commit.assert_called_once()
class TestListBuiltinToolProviderTools:
@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls)
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls)
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider
mock_cache = MagicMock()
@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
assert result == {"result": "success"}
session.delete.assert_called_once_with(db_provider)
session.commit.assert_called_once()
mock_cache.delete.assert_called_once()
class TestSetDefaultProvider:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target
@ -187,14 +193,13 @@ class TestSetDefaultProvider:
assert result == {"result": "success"}
assert target.is_default is True
session.commit.assert_called_once()
class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.CredentialType")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db")
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_session(mock_session_cls)
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider
@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
assert result == {"result": "success"}
session.commit.assert_called_once()
mock_cache.delete.assert_called_once()

View File

@ -60,12 +60,6 @@ def mock_db_session():
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
# Setup query chain
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 0
# Setup scalars for select queries
mock_session.scalars.return_value.all.return_value = []
@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion:
- Pipeline record is deleted
- Related workflow record is deleted
"""
# Arrange
mock_query = mock_db_session.session.query.return_value
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 1
# Act
clean_dataset_task(
dataset_id=dataset_id,
@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion:
pipeline_id=pipeline_id,
)
# Assert - verify delete was called for pipeline-related queries
# The actual count depends on total queries, but pipeline deletion should add 2 more
assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow
# Assert - verify execute was called for delete operations
# 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8
assert mock_db_session.session.execute.call_count >= 8
def test_clean_dataset_task_without_pipeline_id(
self,
@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion:
Expected behavior:
- Pipeline and workflow deletion queries are not executed
"""
# Arrange
mock_query = mock_db_session.session.query.return_value
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 1
# Act
clean_dataset_task(
dataset_id=dataset_id,
@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion:
pipeline_id=None,
)
# Assert - verify delete was called only for base queries (5 times)
assert mock_query.delete.call_count == 5
# Assert - verify execute was called for delete operations
# 1 attachment JOIN query + 5 base deletes = 6
assert mock_db_session.session.execute.call_count == 6
# ============================================================================
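The count flip above reflects 2.0-style bulk deletes: each delete runs as one session.execute(delete(...)) call rather than query(...).delete(), so the tests count execute calls. A toy sketch of the arithmetic with hypothetical table names:

# Each bulk delete is one execute() call; one more execute covers the
# attachment JOIN read, giving the 6-vs-8 totals asserted above.
from unittest.mock import MagicMock

def clean(session, pipeline_id=None):
    session.execute("SELECT ... JOIN attachments ...")  # 1 read
    for table in ("a", "b", "c", "d", "e"):             # 5 base deletes
        session.execute(f"DELETE FROM {table}")
    if pipeline_id is not None:                         # +2 with a pipeline
        session.execute("DELETE FROM pipelines")
        session.execute("DELETE FROM workflows")

s = MagicMock()
clean(s, pipeline_id=None)
assert s.execute.call_count == 6

s = MagicMock()
clean(s, pipeline_id="p1")
assert s.execute.call_count == 8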

View File

@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset):
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
session = MagicMock()
session.scalars.return_value.all.return_value = []
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
session.scalar.side_effect = [mock_document, mock_dataset]
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
@ -242,14 +242,13 @@ class TestDataSourceInfoSerialization:
# DB session mock — shared across all ``session_factory.create_session()`` calls
session = MagicMock()
session.scalars.return_value.all.return_value = []
# .where() path: session 1 reads document + dataset, session 2 reads dataset
session.query.return_value.where.return_value.first.side_effect = [
# All .first() calls are now session.scalar() — ordered by call sequence:
# session 1: document + dataset, session 2: dataset (clean), session 3: document (update),
# session 4: document (indexing)
session.scalar.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
]
# .filter_by() path: session 3 (update), session 4 (indexing)
session.query.return_value.filter_by.return_value.first.side_effect = [
mock_document,
mock_document,
]
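With every former .first() routed through one session.scalar mock, the side_effect list is consumed strictly in call order, which is exactly what the comment above documents. Standard unittest.mock behavior:

# side_effect acts as an ordered queue: each call pops the next value.
from unittest.mock import MagicMock

session = MagicMock()
doc, ds = object(), object()
session.scalar.side_effect = [doc, ds, ds, doc, doc]

assert session.scalar("s1") is doc  # session 1: document
assert session.scalar("s1") is ds   # session 1: dataset
assert session.scalar("s2") is ds   # session 2: dataset (clean)
assert session.scalar("s3") is doc  # session 3: document (update)
assert session.scalar("s4") is doc  # session 4: document (indexing)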

View File

@ -469,6 +469,7 @@ S3_REGION=us-east-1
S3_BUCKET_NAME=difyai
S3_ACCESS_KEY=
S3_SECRET_KEY=
S3_ADDRESS_STYLE=auto
# Whether to use AWS managed IAM roles for authenticating with the S3 service.
# If set to false, the access key and secret key must be provided.
S3_USE_AWS_MANAGED_IAM=false
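S3_ADDRESS_STYLE conventionally maps to botocore's addressing_style option (auto, path, or virtual); path is the usual choice for self-hosted S3-compatible endpoints. A hedged sketch of the standard boto3 wiring; whether the backend forwards the variable exactly this way is an assumption:

# Standard botocore API; the exact wiring inside the storage backend is assumed.
import os

import boto3
from botocore.client import Config

s3 = boto3.client(
    "s3",
    region_name=os.getenv("S3_REGION", "us-east-1"),
    aws_access_key_id=os.getenv("S3_ACCESS_KEY"),
    aws_secret_access_key=os.getenv("S3_SECRET_KEY"),
    # "auto" lets botocore choose; "virtual" uses bucket-as-subdomain URLs.
    config=Config(s3={"addressing_style": os.getenv("S3_ADDRESS_STYLE", "auto")}),
)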

View File

@ -131,6 +131,7 @@ x-shared-env: &shared-api-worker-env
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
S3_ACCESS_KEY: ${S3_ACCESS_KEY:-}
S3_SECRET_KEY: ${S3_SECRET_KEY:-}
S3_ADDRESS_STYLE: ${S3_ADDRESS_STYLE:-auto}
S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false}
ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false}
ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}

View File

@ -1,14 +1,13 @@
/* eslint-disable ts/no-explicit-any */
import type { ReactNode } from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { act, fireEvent, screen, waitFor } from '@testing-library/react'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { AppModeEnum } from '@/types/app'
import ConversationList from '../list'
const mockFetchChatMessages = vi.fn()
const mockUpdateLogMessageFeedbacks = vi.fn()
const mockUpdateLogMessageAnnotations = vi.fn()
const mockPush = vi.fn()
const mockReplace = vi.fn()
const mockOnRefresh = vi.fn()
const mockSetCurrentLogItem = vi.fn()
const mockSetShowPromptLogModal = vi.fn()
@ -17,7 +16,6 @@ const mockSetShowMessageLogModal = vi.fn()
const mockCompletionRefetch = vi.fn()
const mockDelAnnotation = vi.fn()
let mockSearchParams = new URLSearchParams()
let mockChatConversationDetail: Record<string, unknown> | undefined
let mockCompletionConversationDetail: Record<string, unknown> | undefined
let mockShowMessageLogModal = false
@ -53,18 +51,6 @@ vi.mock('@/hooks/use-breakpoints', () => ({
},
}))
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
}),
usePathname: () => '/apps/app-1/logs',
useSearchParams: () => ({
get: (key: string) => mockSearchParams.get(key),
toString: () => mockSearchParams.toString(),
}),
}))
vi.mock('@/service/use-log', () => ({
useChatConversationDetail: () => ({
data: mockChatConversationDetail,
@ -256,10 +242,28 @@ const createChatMessage = (id: string, overrides: Record<string, unknown> = {})
...overrides,
})
const renderConversationList = ({
appDetail = { id: 'app-1', mode: AppModeEnum.CHAT } as any,
logs = createLogs() as any,
searchParams = '?page=2',
}: {
appDetail?: any
logs?: any
searchParams?: string
} = {}) => {
return renderWithNuqs(
<ConversationList
appDetail={appDetail}
logs={logs}
onRefresh={mockOnRefresh}
/>,
{ searchParams },
)
}
describe('ConversationList', () => {
beforeEach(() => {
vi.clearAllMocks()
mockSearchParams = new URLSearchParams('page=2')
mockChatConversationDetail = undefined
mockCompletionConversationDetail = undefined
mockShowMessageLogModal = false
@ -273,34 +277,29 @@ describe('ConversationList', () => {
})
})
it('should render chat rows and push the conversation id into the url when a row is clicked', () => {
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
logs={createLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
it('should render chat rows and push the conversation id into the url when a row is clicked', async () => {
const { onUrlUpdate } = renderConversationList()
expect(screen.getByText('hello world')).toBeInTheDocument()
expect(screen.getAllByText('formatted-1710000000')).toHaveLength(2)
fireEvent.click(screen.getByText('hello world'))
expect(mockPush).toHaveBeenCalledWith('/apps/app-1/logs?page=2&conversation_id=conversation-1', { scroll: false })
expect(screen.getByTestId('drawer')).toBeInTheDocument()
await waitFor(() => {
expect(onUrlUpdate).toHaveBeenCalled()
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(update.searchParams.get('page')).toBe('2')
expect(update.searchParams.get('conversation_id')).toBe('conversation-1')
expect(update.options.history).toBe('push')
})
it('should close the drawer, refresh, and clear modal flags', () => {
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
logs={createLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
it('should close the drawer, refresh, and clear modal flags', async () => {
const { onUrlUpdate } = renderConversationList({
searchParams: '?page=2&conversation_id=conversation-1',
})
fireEvent.click(screen.getByText('close-drawer'))
@ -308,11 +307,18 @@ describe('ConversationList', () => {
expect(mockSetShowPromptLogModal).toHaveBeenCalledWith(false)
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
expect(mockSetShowMessageLogModal).toHaveBeenCalledWith(false)
expect(mockReplace).toHaveBeenCalledWith('/apps/app-1/logs?page=2', { scroll: false })
await waitFor(() => {
expect(onUrlUpdate).toHaveBeenCalled()
})
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(update.searchParams.get('page')).toBe('2')
expect(update.searchParams.has('conversation_id')).toBe(false)
expect(update.options.history).toBe('replace')
})
it('should render chat conversation details and submit feedback from the chat panel', async () => {
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
mockChatConversationDetail = {
id: 'conversation-1',
created_at: 1710000000,
@ -355,13 +361,9 @@ describe('ConversationList', () => {
mockShowMessageLogModal = true
mockCurrentLogItem = { id: 'log-1' }
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
logs={createLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
renderConversationList({
searchParams: '?page=2&conversation_id=conversation-1',
})
await waitFor(() => {
expect(mockFetchChatMessages).toHaveBeenCalledWith({
@ -396,7 +398,6 @@ describe('ConversationList', () => {
})
it('should render completion details and refetch after feedback updates', async () => {
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
mockCompletionConversationDetail = {
id: 'conversation-1',
created_at: 1710000000,
@ -423,13 +424,11 @@ describe('ConversationList', () => {
mockShowPromptLogModal = true
mockCurrentLogItem = { id: 'log-2' }
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
logs={createCompletionLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
renderConversationList({
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
logs: createCompletionLogs() as any,
searchParams: '?page=2&conversation_id=conversation-1',
})
await waitFor(() => {
expect(screen.getByTestId('text-generation')).toBeInTheDocument()
@ -454,64 +453,61 @@ describe('ConversationList', () => {
})
it('should render chatflow status cells and feedback counters for advanced chat logs', () => {
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any}
logs={{
data: [
{
id: 'conversation-pending',
name: 'Pending row',
from_account_name: 'user-a',
read_at: 1710000001,
message_count: 3,
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
user_feedback_stats: { like: 2, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 1 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-success',
name: 'Success row',
from_account_name: 'user-b',
read_at: 1710000001,
message_count: 4,
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-partial',
name: 'Partial row',
from_account_name: 'user-c',
read_at: 1710000001,
message_count: 5,
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-failure',
name: 'Failure row',
from_account_name: 'user-d',
read_at: 1710000001,
message_count: 1,
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
],
} as any}
onRefresh={mockOnRefresh}
/>,
)
renderConversationList({
appDetail: { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any,
logs: {
data: [
{
id: 'conversation-pending',
name: 'Pending row',
from_account_name: 'user-a',
read_at: 1710000001,
message_count: 3,
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
user_feedback_stats: { like: 2, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 1 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-success',
name: 'Success row',
from_account_name: 'user-b',
read_at: 1710000001,
message_count: 4,
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-partial',
name: 'Partial row',
from_account_name: 'user-c',
read_at: 1710000001,
message_count: 5,
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
{
id: 'conversation-failure',
name: 'Failure row',
from_account_name: 'user-d',
read_at: 1710000001,
message_count: 1,
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
user_feedback_stats: { like: 0, dislike: 0 },
admin_feedback_stats: { like: 0, dislike: 0 },
updated_at: 1710000000,
created_at: 1710000000,
},
],
} as any,
})
expect(screen.getByText('Pending')).toBeInTheDocument()
expect(screen.getByText('Success')).toBeInTheDocument()
@ -522,7 +518,6 @@ describe('ConversationList', () => {
})
it('should support annotation changes, modal closing, and paginated scroll loading in the detail drawer', async () => {
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
mockChatConversationDetail = {
id: 'conversation-1',
created_at: 1710000000,
@ -568,13 +563,9 @@ describe('ConversationList', () => {
has_more: false,
})
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
logs={createLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
renderConversationList({
searchParams: '?page=2&conversation_id=conversation-1',
})
await waitFor(() => {
expect(screen.getByTestId('chat-panel')).toBeInTheDocument()
@ -609,7 +600,6 @@ describe('ConversationList', () => {
})
it('should close the prompt log modal from completion detail drawers', async () => {
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
mockCompletionConversationDetail = {
id: 'conversation-1',
created_at: 1710000000,
@ -636,13 +626,11 @@ describe('ConversationList', () => {
mockShowPromptLogModal = true
mockCurrentLogItem = { id: 'log-2' }
render(
<ConversationList
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
logs={createCompletionLogs() as any}
onRefresh={mockOnRefresh}
/>,
)
renderConversationList({
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
logs: createCompletionLogs() as any,
searchParams: '?page=2&conversation_id=conversation-1',
})
expect(await screen.findByTestId('prompt-log-modal')).toBeInTheDocument()

View File

@ -13,6 +13,7 @@ import dayjs from 'dayjs'
import timezone from 'dayjs/plugin/timezone'
import utc from 'dayjs/plugin/utc'
import { noop } from 'es-toolkit/function'
import { parseAsString, useQueryState } from 'nuqs'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@ -33,7 +34,6 @@ import { WorkflowContextProvider } from '@/app/components/workflow/context'
import { useAppContext } from '@/context/app-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useTimestamp from '@/hooks/use-timestamp'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log'
import { AppSourceType } from '@/service/share'
import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log'
@ -46,7 +46,6 @@ import {
applyAnnotationEdited,
applyAnnotationRemoved,
buildChatThreadState,
buildConversationUrl,
getCompletionMessageFiles,
getConversationRowValues,
getDetailVarList,
@ -674,10 +673,7 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string }
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
const { t } = useTranslation()
const { formatTime } = useTimestamp()
const router = useRouter()
const pathname = usePathname()
const searchParams = useSearchParams()
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
const [conversationIdInUrl, setConversationIdInUrl] = useQueryState('conversation_id', parseAsString)
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
@ -697,8 +693,6 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
const buildUrlWithConversation = useCallback((conversationId?: string) => buildConversationUrl(pathname, searchParams.toString(), conversationId), [pathname, searchParams])
const handleRowClick = useCallback((log: ConversationListItem) => {
if (conversationIdInUrl === log.id) {
if (!showDrawer)
@ -717,8 +711,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
if (currentConversation?.id !== log.id)
setCurrentConversation(undefined)
router.push(buildUrlWithConversation(log.id), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
void setConversationIdInUrl(log.id, { history: 'push' })
}, [conversationIdInUrl, currentConversation, setConversationIdInUrl, showDrawer])
const currentConversationId = currentConversation?.id
@ -755,7 +749,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
pendingConversationCacheRef.current = undefined
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
}, [conversationIdInUrl, currentConversation, currentConversationId, logs?.data, showDrawer])
const onCloseDrawer = useCallback(() => {
onRefresh()
@ -769,8 +763,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
closingConversationIdRef.current = conversationIdInUrl ?? null
if (conversationIdInUrl)
router.replace(buildUrlWithConversation(), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
void setConversationIdInUrl(null, { history: 'replace' })
}, [conversationIdInUrl, onRefresh, setConversationIdInUrl, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
// Annotated data needs to be highlighted
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {