mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
5c93d74dec
@ -109,6 +109,7 @@ S3_BUCKET_NAME=your-bucket-name
|
||||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
S3_ADDRESS_STYLE=auto
|
||||
|
||||
# Workflow run and Conversation archive storage (S3-compatible)
|
||||
ARCHIVE_STORAGE_ENABLED=false
|
||||
|
||||
@ -34,9 +34,10 @@ from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
|
||||
@ -17,8 +17,9 @@ from fields.app_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from models.account import AccountStatus
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
|
||||
|
||||
class InnerAppDSLImportPayload(BaseModel):
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
|
||||
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self._client = create_engine(self._url)
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
session.commit()
|
||||
self._fields: list[str] = []
|
||||
|
||||
class _Table(CollectionORM):
|
||||
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
pks = []
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
for document, embedding in zip(documents, embeddings):
|
||||
pk = uuid4()
|
||||
session.execute(
|
||||
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
|
||||
),
|
||||
)
|
||||
pks.append(pk)
|
||||
session.commit()
|
||||
|
||||
return pks
|
||||
|
||||
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self._client) as session:
|
||||
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete(self):
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
|
||||
@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel):
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
match self.type:
|
||||
case ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"authentication", self.authentication.model_dump() if self.authentication else None
|
||||
)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
case ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
case _:
|
||||
pass
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -166,13 +167,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -190,13 +185,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
@ -210,13 +199,7 @@ class ToolFileManager:
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -234,13 +217,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
session.query(Workflow)
|
||||
workflow: Workflow | None = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return self.tools
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
db_provider: WorkflowToolProvider | None = session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
@ -813,56 +813,32 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dump_optional(value: Any) -> str | None:
|
||||
return json.dumps(value) if value else None
|
||||
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
self.suggested_questions = self._dump_optional(model_config.get("suggested_questions"))
|
||||
self.suggested_questions_after_answer = self._dump_optional(
|
||||
model_config.get("suggested_questions_after_answer")
|
||||
)
|
||||
self.speech_to_text = self._dump_optional(model_config.get("speech_to_text"))
|
||||
self.text_to_speech = self._dump_optional(model_config.get("text_to_speech"))
|
||||
self.more_like_this = self._dump_optional(model_config.get("more_like_this"))
|
||||
self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance"))
|
||||
self.external_data_tools = self._dump_optional(model_config.get("external_data_tools"))
|
||||
self.model = self._dump_optional(model_config.get("model"))
|
||||
self.user_input_form = self._dump_optional(model_config.get("user_input_form"))
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.agent_mode = self._dump_optional(model_config.get("agent_mode"))
|
||||
self.retriever_resource = self._dump_optional(model_config.get("retriever_resource"))
|
||||
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
|
||||
self.chat_prompt_config = (
|
||||
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
|
||||
)
|
||||
self.completion_prompt_config = (
|
||||
json.dumps(model_config.get("completion_prompt_config"))
|
||||
if model_config.get("completion_prompt_config")
|
||||
else None
|
||||
)
|
||||
self.dataset_configs = (
|
||||
json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None
|
||||
)
|
||||
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
|
||||
self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config"))
|
||||
self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config"))
|
||||
self.dataset_configs = self._dump_optional(model_config.get("dataset_configs"))
|
||||
self.file_upload = self._dump_optional(model_config.get("file_upload"))
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.6.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
@ -75,10 +63,6 @@ class Import(BaseModel):
|
||||
error: str = ""
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -53,13 +53,12 @@ class DatasourceProviderService:
|
||||
"""
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
@ -109,7 +108,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get credential by id
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
@ -156,7 +155,6 @@ class DatasourceProviderService:
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
@ -174,7 +172,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -224,7 +222,6 @@ class DatasourceProviderService:
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
session.commit()
|
||||
|
||||
return real_credentials_list
|
||||
|
||||
@ -234,7 +231,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource provider name
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
@ -266,7 +263,6 @@ class DatasourceProviderService:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
session.commit()
|
||||
return
|
||||
|
||||
def set_default_datasource_provider(
|
||||
@ -275,7 +271,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
set default datasource provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
@ -300,7 +296,6 @@ class DatasourceProviderService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
@ -315,7 +310,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
@ -349,7 +344,6 @@ class DatasourceProviderService:
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
@ -488,7 +482,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource oauth provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
target_provider = (
|
||||
@ -535,7 +529,6 @@ class DatasourceProviderService:
|
||||
target_provider.expires_at = expire_at
|
||||
target_provider.encrypted_credentials = credentials
|
||||
target_provider.avatar_url = avatar_url or target_provider.avatar_url
|
||||
session.commit()
|
||||
|
||||
def add_datasource_oauth_provider(
|
||||
self,
|
||||
@ -550,7 +543,7 @@ class DatasourceProviderService:
|
||||
add datasource oauth provider
|
||||
"""
|
||||
credential_type = CredentialType.OAUTH2
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
||||
with redis_client.lock(lock, timeout=60):
|
||||
db_provider_name = name
|
||||
@ -604,7 +597,6 @@ class DatasourceProviderService:
|
||||
expires_at=expire_at,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def add_datasource_api_key_provider(
|
||||
self,
|
||||
@ -623,7 +615,7 @@ class DatasourceProviderService:
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
db_provider_name = name or self.generate_next_datasource_provider_name(
|
||||
@ -670,7 +662,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
||||
"""
|
||||
@ -926,7 +917,7 @@ class DatasourceProviderService:
|
||||
update datasource credentials.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -980,7 +971,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials[key] = value
|
||||
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
session.commit()
|
||||
|
||||
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
||||
"""
|
||||
|
||||
21
api/services/entities/dsl_entities.py
Normal file
21
api/services/entities/dsl_entities.py
Normal file
@ -0,0 +1,21 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
@ -1,3 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -8,10 +9,10 @@ class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
return session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -24,10 +25,10 @@ class PluginAutoUpgradeService:
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
@ -51,10 +52,10 @@ class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -8,7 +9,9 @@ class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
return session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
@ -17,8 +20,8 @@ class PluginPermissionService:
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
|
||||
@ -555,7 +555,7 @@ class RagPipelineService:
|
||||
workflow_node_execution.id
|
||||
)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -569,7 +569,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
@ -1325,7 +1324,7 @@ class RagPipelineService:
|
||||
# Convert node_execution to WorkflowNodeExecution after save
|
||||
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -1339,7 +1338,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
|
||||
@ -20,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -37,7 +37,7 @@ from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.enums import CollectionBindingType, DatasetRuntimeMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.app_dsl_service import ImportMode, ImportStatus
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
KnowledgeConfiguration,
|
||||
@ -64,10 +64,6 @@ class RagPipelineImportInfo(BaseModel):
|
||||
dataset_id: str | None = None
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -46,13 +46,12 @@ class BuiltinToolManageService:
|
||||
delete custom oauth client params
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -150,7 +149,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
@ -203,9 +202,7 @@ class BuiltinToolManageService:
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@ -222,7 +219,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
@ -281,9 +278,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
@ -379,7 +374,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
@ -393,7 +388,6 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -409,7 +403,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||
if target_provider is None:
|
||||
@ -422,7 +416,6 @@ class BuiltinToolManageService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -654,7 +647,7 @@ class BuiltinToolManageService:
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@ -690,7 +683,6 @@ class BuiltinToolManageService:
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -48,21 +48,25 @@ class ToolTransformService:
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
return ""
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
case ToolProviderType.API | ToolProviderType.WORKFLOW:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
case ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
return ""
|
||||
case _:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
|
||||
@ -1075,9 +1075,8 @@ class DraftVariableSaver:
|
||||
)
|
||||
engine = bind = self._session.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
|
||||
session.add(variable_file)
|
||||
session.commit()
|
||||
|
||||
return truncation_result.result, variable_file
|
||||
|
||||
|
||||
@ -837,7 +837,7 @@ class WorkflowService:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -848,7 +848,6 @@ class WorkflowService:
|
||||
user=account,
|
||||
)
|
||||
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
|
||||
session.commit()
|
||||
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution,
|
||||
@ -977,7 +976,7 @@ class WorkflowService:
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -988,7 +987,6 @@ class WorkflowService:
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
)
|
||||
draft_var_saver.save(outputs=outputs, process_data={})
|
||||
session.commit()
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@ -112,7 +112,9 @@ def clean_dataset_task(
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
@ -150,20 +152,22 @@ def clean_dataset_task(
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
|
||||
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
|
||||
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
# delete dataset metadata
|
||||
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
|
||||
# delete pipeline and workflow
|
||||
if pipeline_id:
|
||||
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
|
||||
session.query(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
).delete()
|
||||
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
|
||||
session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
)
|
||||
)
|
||||
# delete files
|
||||
if documents:
|
||||
file_ids = []
|
||||
@ -174,7 +178,7 @@ def clean_dataset_task(
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_ids.append(file_id)
|
||||
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
|
||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
storage.delete(file.key)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
|
||||
@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
tenant_id = None
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
|
||||
@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.exception("Failed to clean vector index for document %s", document_id)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if not document:
|
||||
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
|
||||
return
|
||||
@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
except Exception as e:
|
||||
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
|
||||
@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
|
||||
return _session
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _sessionmaker_factory(calls, execute_results=None):
|
||||
def _sessionmaker(*args, **kwargs):
|
||||
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
|
||||
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
|
||||
return _sessionmaker
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
"""Patch both Session and sessionmaker on the module with the same call tracker."""
|
||||
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
vector.create_collection = MagicMock()
|
||||
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
|
||||
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
|
||||
|
||||
class _InsertBuilder:
|
||||
def __init__(self, table):
|
||||
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
"Session",
|
||||
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
],
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
vector.delete_by_ids(["doc-1"])
|
||||
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
runtime_calls.clear()
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
|
||||
vector.delete()
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
runtime_calls = []
|
||||
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
|
||||
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
|
||||
]
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
|
||||
assert len(docs) == 1
|
||||
|
||||
@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url=None)
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
|
||||
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = tool_file
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, tool_file]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
size=12,
|
||||
)
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
|
||||
@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
|
||||
parameter_configurations=[],
|
||||
)
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
fake_cm = MagicMock()
|
||||
fake_cm.__enter__.return_value = session
|
||||
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
|
||||
assert controller.get_tools("tenant-1") == []
|
||||
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
session.scalar.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -31,7 +31,7 @@ class TestGetStrategy:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -44,9 +44,9 @@ class TestGetStrategy:
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
@ -65,7 +65,7 @@ class TestChangeStrategy:
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -90,11 +90,12 @@ class TestChangeStrategy:
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.select"),
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
@ -113,9 +114,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -131,9 +132,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -148,9 +149,9 @@ class TestExcludePlugin:
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -167,9 +168,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
session.scalar.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -31,7 +31,7 @@ class TestGetPermission:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -44,9 +44,9 @@ class TestGetPermission:
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -59,7 +59,7 @@ class TestChangePermission:
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -40,7 +40,10 @@ class TestDatasourceProviderService:
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
"""
|
||||
with patch("services.datasource_provider_service.Session") as mock_cls:
|
||||
with (
|
||||
patch("services.datasource_provider_service.Session") as mock_cls,
|
||||
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
|
||||
):
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
@ -63,6 +66,8 @@ class TestDatasourceProviderService:
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
yield sess
|
||||
|
||||
@ -266,7 +271,6 @@ class TestDatasourceProviderService:
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
@ -333,7 +337,6 @@ class TestDatasourceProviderService:
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -352,7 +355,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# set_default_datasource_provider (lines 277-303)
|
||||
@ -370,7 +372,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().first.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_encrypter (lines 404-420)
|
||||
@ -460,7 +461,6 @@ class TestDatasourceProviderService:
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
@ -512,7 +512,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -523,7 +522,6 @@ class TestDatasourceProviderService:
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -571,7 +569,6 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
@ -747,7 +744,6 @@ class TestDatasourceProviderService:
|
||||
# encrypter must have been called with the new secret value
|
||||
self._enc.encrypt_token.assert_called()
|
||||
# commit must be called exactly once
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_datasource_credentials (lines 980-997)
|
||||
@ -758,7 +754,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
|
||||
@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
|
||||
return session
|
||||
|
||||
|
||||
def _mock_sessionmaker(mock_sm_cls):
|
||||
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
|
||||
session = MagicMock()
|
||||
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return session
|
||||
|
||||
|
||||
class TestDeleteCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
|
||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListBuiltinToolProviderTools:
|
||||
@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
|
||||
class TestDeleteBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mock_cache = MagicMock()
|
||||
@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.delete.assert_called_once_with(db_provider)
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestSetDefaultProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="provider not found"):
|
||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
target = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||
|
||||
@ -187,14 +193,13 @@ class TestSetDefaultProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert target.is_default is True
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.CredentialType")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
|
||||
@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@ -60,12 +60,6 @@ def mock_db_session():
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
# Setup query chain
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 0
|
||||
|
||||
# Setup scalars for select queries
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion:
|
||||
- Pipeline record is deleted
|
||||
- Related workflow record is deleted
|
||||
"""
|
||||
# Arrange
|
||||
mock_query = mock_db_session.session.query.return_value
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
dataset_id=dataset_id,
|
||||
@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion:
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
# Assert - verify delete was called for pipeline-related queries
|
||||
# The actual count depends on total queries, but pipeline deletion should add 2 more
|
||||
assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow
|
||||
# Assert - verify execute was called for delete operations
|
||||
# 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8
|
||||
assert mock_db_session.session.execute.call_count >= 8
|
||||
|
||||
def test_clean_dataset_task_without_pipeline_id(
|
||||
self,
|
||||
@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion:
|
||||
Expected behavior:
|
||||
- Pipeline and workflow deletion queries are not executed
|
||||
"""
|
||||
# Arrange
|
||||
mock_query = mock_db_session.session.query.return_value
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
clean_dataset_task(
|
||||
dataset_id=dataset_id,
|
||||
@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion:
|
||||
pipeline_id=None,
|
||||
)
|
||||
|
||||
# Assert - verify delete was called only for base queries (5 times)
|
||||
assert mock_query.delete.call_count == 5
|
||||
# Assert - verify execute was called for delete operations
|
||||
# 1 attachment JOIN query + 5 base deletes = 6
|
||||
assert mock_db_session.session.execute.call_count == 6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset):
|
||||
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
session.scalar.side_effect = [mock_document, mock_dataset]
|
||||
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
@ -242,14 +242,13 @@ class TestDataSourceInfoSerialization:
|
||||
# DB session mock — shared across all ``session_factory.create_session()`` calls
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
# .where() path: session 1 reads document + dataset, session 2 reads dataset
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
# All .first() calls are now session.scalar() — ordered by call sequence:
|
||||
# session 1: document + dataset, session 2: dataset (clean), session 3: document (update),
|
||||
# session 4: document (indexing)
|
||||
session.scalar.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
]
|
||||
# .filter_by() path: session 3 (update), session 4 (indexing)
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
|
||||
@ -469,6 +469,7 @@ S3_REGION=us-east-1
|
||||
S3_BUCKET_NAME=difyai
|
||||
S3_ACCESS_KEY=
|
||||
S3_SECRET_KEY=
|
||||
S3_ADDRESS_STYLE=auto
|
||||
# Whether to use AWS managed IAM roles for authenticating with the S3 service.
|
||||
# If set to false, the access key and secret key must be provided.
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
|
||||
@ -131,6 +131,7 @@ x-shared-env: &shared-api-worker-env
|
||||
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
|
||||
S3_ACCESS_KEY: ${S3_ACCESS_KEY:-}
|
||||
S3_SECRET_KEY: ${S3_SECRET_KEY:-}
|
||||
S3_ADDRESS_STYLE: ${S3_ADDRESS_STYLE:-auto}
|
||||
S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false}
|
||||
ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false}
|
||||
ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
/* eslint-disable ts/no-explicit-any */
|
||||
import type { ReactNode } from 'react'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { act, fireEvent, screen, waitFor } from '@testing-library/react'
|
||||
import { renderWithNuqs } from '@/test/nuqs-testing'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import ConversationList from '../list'
|
||||
|
||||
const mockFetchChatMessages = vi.fn()
|
||||
const mockUpdateLogMessageFeedbacks = vi.fn()
|
||||
const mockUpdateLogMessageAnnotations = vi.fn()
|
||||
const mockPush = vi.fn()
|
||||
const mockReplace = vi.fn()
|
||||
const mockOnRefresh = vi.fn()
|
||||
const mockSetCurrentLogItem = vi.fn()
|
||||
const mockSetShowPromptLogModal = vi.fn()
|
||||
@ -17,7 +16,6 @@ const mockSetShowMessageLogModal = vi.fn()
|
||||
const mockCompletionRefetch = vi.fn()
|
||||
const mockDelAnnotation = vi.fn()
|
||||
|
||||
let mockSearchParams = new URLSearchParams()
|
||||
let mockChatConversationDetail: Record<string, unknown> | undefined
|
||||
let mockCompletionConversationDetail: Record<string, unknown> | undefined
|
||||
let mockShowMessageLogModal = false
|
||||
@ -53,18 +51,6 @@ vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
replace: mockReplace,
|
||||
}),
|
||||
usePathname: () => '/apps/app-1/logs',
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => mockSearchParams.get(key),
|
||||
toString: () => mockSearchParams.toString(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-log', () => ({
|
||||
useChatConversationDetail: () => ({
|
||||
data: mockChatConversationDetail,
|
||||
@ -256,10 +242,28 @@ const createChatMessage = (id: string, overrides: Record<string, unknown> = {})
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderConversationList = ({
|
||||
appDetail = { id: 'app-1', mode: AppModeEnum.CHAT } as any,
|
||||
logs = createLogs() as any,
|
||||
searchParams = '?page=2',
|
||||
}: {
|
||||
appDetail?: any
|
||||
logs?: any
|
||||
searchParams?: string
|
||||
} = {}) => {
|
||||
return renderWithNuqs(
|
||||
<ConversationList
|
||||
appDetail={appDetail}
|
||||
logs={logs}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
{ searchParams },
|
||||
)
|
||||
}
|
||||
|
||||
describe('ConversationList', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSearchParams = new URLSearchParams('page=2')
|
||||
mockChatConversationDetail = undefined
|
||||
mockCompletionConversationDetail = undefined
|
||||
mockShowMessageLogModal = false
|
||||
@ -273,34 +277,29 @@ describe('ConversationList', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should render chat rows and push the conversation id into the url when a row is clicked', () => {
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
it('should render chat rows and push the conversation id into the url when a row is clicked', async () => {
|
||||
const { onUrlUpdate } = renderConversationList()
|
||||
|
||||
expect(screen.getByText('hello world')).toBeInTheDocument()
|
||||
expect(screen.getAllByText('formatted-1710000000')).toHaveLength(2)
|
||||
|
||||
fireEvent.click(screen.getByText('hello world'))
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/apps/app-1/logs?page=2&conversation_id=conversation-1', { scroll: false })
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
await waitFor(() => {
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.get('page')).toBe('2')
|
||||
expect(update.searchParams.get('conversation_id')).toBe('conversation-1')
|
||||
expect(update.options.history).toBe('push')
|
||||
})
|
||||
|
||||
it('should close the drawer, refresh, and clear modal flags', () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
it('should close the drawer, refresh, and clear modal flags', async () => {
|
||||
const { onUrlUpdate } = renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('close-drawer'))
|
||||
|
||||
@ -308,11 +307,18 @@ describe('ConversationList', () => {
|
||||
expect(mockSetShowPromptLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockSetShowMessageLogModal).toHaveBeenCalledWith(false)
|
||||
expect(mockReplace).toHaveBeenCalledWith('/apps/app-1/logs?page=2', { scroll: false })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onUrlUpdate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
const update = onUrlUpdate.mock.calls.at(-1)![0]
|
||||
expect(update.searchParams.get('page')).toBe('2')
|
||||
expect(update.searchParams.has('conversation_id')).toBe(false)
|
||||
expect(update.options.history).toBe('replace')
|
||||
})
|
||||
|
||||
it('should render chat conversation details and submit feedback from the chat panel', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -355,13 +361,9 @@ describe('ConversationList', () => {
|
||||
mockShowMessageLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-1' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchChatMessages).toHaveBeenCalledWith({
|
||||
@ -396,7 +398,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should render completion details and refetch after feedback updates', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -423,13 +424,11 @@ describe('ConversationList', () => {
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
|
||||
logs={createCompletionLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
logs: createCompletionLogs() as any,
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('text-generation')).toBeInTheDocument()
|
||||
@ -454,64 +453,61 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should render chatflow status cells and feedback counters for advanced chat logs', () => {
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any}
|
||||
logs={{
|
||||
data: [
|
||||
{
|
||||
id: 'conversation-pending',
|
||||
name: 'Pending row',
|
||||
from_account_name: 'user-a',
|
||||
read_at: 1710000001,
|
||||
message_count: 3,
|
||||
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 2, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 1 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-success',
|
||||
name: 'Success row',
|
||||
from_account_name: 'user-b',
|
||||
read_at: 1710000001,
|
||||
message_count: 4,
|
||||
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-partial',
|
||||
name: 'Partial row',
|
||||
from_account_name: 'user-c',
|
||||
read_at: 1710000001,
|
||||
message_count: 5,
|
||||
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-failure',
|
||||
name: 'Failure row',
|
||||
from_account_name: 'user-d',
|
||||
read_at: 1710000001,
|
||||
message_count: 1,
|
||||
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
],
|
||||
} as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any,
|
||||
logs: {
|
||||
data: [
|
||||
{
|
||||
id: 'conversation-pending',
|
||||
name: 'Pending row',
|
||||
from_account_name: 'user-a',
|
||||
read_at: 1710000001,
|
||||
message_count: 3,
|
||||
status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 2, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 1 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-success',
|
||||
name: 'Success row',
|
||||
from_account_name: 'user-b',
|
||||
read_at: 1710000001,
|
||||
message_count: 4,
|
||||
status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-partial',
|
||||
name: 'Partial row',
|
||||
from_account_name: 'user-c',
|
||||
read_at: 1710000001,
|
||||
message_count: 5,
|
||||
status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
{
|
||||
id: 'conversation-failure',
|
||||
name: 'Failure row',
|
||||
from_account_name: 'user-d',
|
||||
read_at: 1710000001,
|
||||
message_count: 1,
|
||||
status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 },
|
||||
user_feedback_stats: { like: 0, dislike: 0 },
|
||||
admin_feedback_stats: { like: 0, dislike: 0 },
|
||||
updated_at: 1710000000,
|
||||
created_at: 1710000000,
|
||||
},
|
||||
],
|
||||
} as any,
|
||||
})
|
||||
|
||||
expect(screen.getByText('Pending')).toBeInTheDocument()
|
||||
expect(screen.getByText('Success')).toBeInTheDocument()
|
||||
@ -522,7 +518,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should support annotation changes, modal closing, and paginated scroll loading in the detail drawer', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -568,13 +563,9 @@ describe('ConversationList', () => {
|
||||
has_more: false,
|
||||
})
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.CHAT } as any}
|
||||
logs={createLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('chat-panel')).toBeInTheDocument()
|
||||
@ -609,7 +600,6 @@ describe('ConversationList', () => {
|
||||
})
|
||||
|
||||
it('should close the prompt log modal from completion detail drawers', async () => {
|
||||
mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1')
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
@ -636,13 +626,11 @@ describe('ConversationList', () => {
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
|
||||
render(
|
||||
<ConversationList
|
||||
appDetail={{ id: 'app-1', mode: AppModeEnum.COMPLETION } as any}
|
||||
logs={createCompletionLogs() as any}
|
||||
onRefresh={mockOnRefresh}
|
||||
/>,
|
||||
)
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
logs: createCompletionLogs() as any,
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
expect(await screen.findByTestId('prompt-log-modal')).toBeInTheDocument()
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ import dayjs from 'dayjs'
|
||||
import timezone from 'dayjs/plugin/timezone'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -33,7 +34,6 @@ import { WorkflowContextProvider } from '@/app/components/workflow/context'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log'
|
||||
import { AppSourceType } from '@/service/share'
|
||||
import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log'
|
||||
@ -46,7 +46,6 @@ import {
|
||||
applyAnnotationEdited,
|
||||
applyAnnotationRemoved,
|
||||
buildChatThreadState,
|
||||
buildConversationUrl,
|
||||
getCompletionMessageFiles,
|
||||
getConversationRowValues,
|
||||
getDetailVarList,
|
||||
@ -674,10 +673,7 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string }
|
||||
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
|
||||
const { t } = useTranslation()
|
||||
const { formatTime } = useTimestamp()
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
|
||||
const [conversationIdInUrl, setConversationIdInUrl] = useQueryState('conversation_id', parseAsString)
|
||||
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
@ -697,8 +693,6 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
|
||||
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
|
||||
|
||||
const buildUrlWithConversation = useCallback((conversationId?: string) => buildConversationUrl(pathname, searchParams.toString(), conversationId), [pathname, searchParams])
|
||||
|
||||
const handleRowClick = useCallback((log: ConversationListItem) => {
|
||||
if (conversationIdInUrl === log.id) {
|
||||
if (!showDrawer)
|
||||
@ -717,8 +711,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
if (currentConversation?.id !== log.id)
|
||||
setCurrentConversation(undefined)
|
||||
|
||||
router.push(buildUrlWithConversation(log.id), { scroll: false })
|
||||
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
|
||||
void setConversationIdInUrl(log.id, { history: 'push' })
|
||||
}, [conversationIdInUrl, currentConversation, setConversationIdInUrl, showDrawer])
|
||||
|
||||
const currentConversationId = currentConversation?.id
|
||||
|
||||
@ -755,7 +749,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
|
||||
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
|
||||
pendingConversationCacheRef.current = undefined
|
||||
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
|
||||
}, [conversationIdInUrl, currentConversation, currentConversationId, logs?.data, showDrawer])
|
||||
|
||||
const onCloseDrawer = useCallback(() => {
|
||||
onRefresh()
|
||||
@ -769,8 +763,8 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
closingConversationIdRef.current = conversationIdInUrl ?? null
|
||||
|
||||
if (conversationIdInUrl)
|
||||
router.replace(buildUrlWithConversation(), { scroll: false })
|
||||
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
|
||||
void setConversationIdInUrl(null, { history: 'replace' })
|
||||
}, [conversationIdInUrl, onRefresh, setConversationIdInUrl, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
|
||||
|
||||
// Annotated data needs to be highlighted
|
||||
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user