Merge branch 'main' into deploy/dev

Stephen Zhou 2026-04-09 16:45:53 +08:00
commit d6c3df33c1
175 changed files with 8400 additions and 3247 deletions

View File

@@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
REDIS_RETRY_RETRIES=3
REDIS_RETRY_BACKOFF_BASE=1.0
REDIS_RETRY_BACKOFF_CAP=10.0
REDIS_SOCKET_TIMEOUT=5.0
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
REDIS_HEALTH_CHECK_INTERVAL=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
@@ -102,6 +109,7 @@ S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
S3_ADDRESS_STYLE=auto
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
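
For reference on the new S3_ADDRESS_STYLE entry: with boto3/botocore the addressing style is passed through Config(s3={"addressing_style": ...}), where "auto" lets the client pick between virtual-hosted and path-style requests, and "path" is what many S3-compatible stores such as MinIO expect. A minimal sketch of how these variables could be wired into a client; the construction below is illustrative, not necessarily how the application consumes them:

import os

import boto3
from botocore.config import Config

# Hypothetical wiring: read the env vars above and build an S3 client.
s3 = boto3.client(
    "s3",
    region_name=os.environ.get("S3_REGION"),
    aws_access_key_id=os.environ.get("S3_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("S3_SECRET_KEY"),
    config=Config(s3={"addressing_style": os.environ.get("S3_ADDRESS_STYLE", "auto")}),
)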

View File

@@ -117,6 +117,37 @@ class RedisConfig(BaseSettings):
default=None,
)
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
description="Maximum number of retries per Redis command on "
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
default=3,
)
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
description="Base delay in seconds for exponential backoff between retries",
default=1.0,
)
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
description="Maximum backoff delay in seconds between retries",
default=10.0,
)
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis read/write operations",
default=5.0,
)
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis connection establishment",
default=5.0,
)
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
description="Interval in seconds between Redis connection health checks (0 to disable)",
default=30,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):
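
For context on the retry fields above: redis-py ships Retry plus ExponentialBackoff, which sleeps min(cap, base * 2**failures) between attempts, and the socket and health-check values map onto standard client keyword arguments. A minimal sketch assuming the config is wired into redis-py's stock API, using the defaults declared above:

from redis import Redis
from redis.backoff import ExponentialBackoff
from redis.retry import Retry

retry = Retry(
    backoff=ExponentialBackoff(cap=10.0, base=1.0),  # REDIS_RETRY_BACKOFF_CAP / _BASE
    retries=3,                                       # REDIS_RETRY_RETRIES
)
client = Redis(
    host="localhost",
    port=6379,
    retry=retry,
    socket_timeout=5.0,          # REDIS_SOCKET_TIMEOUT
    socket_connect_timeout=5.0,  # REDIS_SOCKET_CONNECT_TIMEOUT
    health_check_interval=30,    # REDIS_HEALTH_CHECK_INTERVAL (0 disables checks)
)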

View File

@@ -48,11 +48,27 @@ class SavedMessageCreatePayload(BaseModel):
# --- Workflow schemas ---
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
# --- Audio schemas ---
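
These shared schemas lean on Pydantic v2 to coerce raw query-string values and enforce the declared Field bounds (page within 1..99999, limit within 1..100). A small usage sketch, assuming Pydantic v2 and that WorkflowListQuery is importable from controllers.common.controller_schemas:

from pydantic import ValidationError

query = WorkflowListQuery.model_validate({"page": "2", "limit": "50"})
assert query.page == 2 and query.limit == 50 and query.named_only is False

try:
    WorkflowListQuery.model_validate({"limit": 500})  # violates le=100
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "less_than_equal"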

View File

@@ -34,9 +34,10 @@ from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,

View File

@@ -17,8 +17,9 @@ from fields.app_fields import (
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.app_dsl_service import AppDslService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportStatus
from services.feature_service import FeatureService
from .. import console_ns
@@ -92,11 +93,13 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")
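
One subtlety of the new match form: without a catch-all arm, an ImportStatus member missing from the cases would fall through and the handler would implicitly return None. The assert_never pattern (Python 3.11+) lets type checkers prove such matches exhaustive; the Status enum below is a stand-in, not the real ImportStatus:

from enum import Enum
from typing import assert_never

class Status(Enum):
    FAILED = "failed"
    PENDING = "pending"
    COMPLETED = "completed"

def status_code(status: Status) -> int:
    match status:
        case Status.FAILED:
            return 400
        case Status.PENDING:
            return 202
        case Status.COMPLETED:
            return 200
        case _:
            assert_never(status)  # type checkers flag any unhandled member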

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
@@ -142,10 +143,6 @@ class PublishWorkflowPayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
@@ -153,18 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str

View File

@@ -384,24 +384,27 @@ class VariableApi(Resource):
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource):
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str

View File

@@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
match app.mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
from extensions.ext_database import db
from models import Account, App
from models.account import AccountStatus
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
from services.app_dsl_service import AppDslService
from services.entities.dsl_entities import ImportMode, ImportStatus
class InnerAppDSLImportPayload(BaseModel):

View File

@@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
match auth_type:
case WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
case WebAppAuthType.EXTERNAL:
if user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
case WebAppAuthType.INTERNAL:
if user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None
if end_user_id:

View File

@@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource):
app_mode = AppMode.value_of(app_model.mode)
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
match app_mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import Variable
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:
@@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:

View File

@@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes
from graphon.runtime import GraphRuntimeState
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@@ -2,7 +2,6 @@ import logging
import time
from typing import cast
from graphon.entities import GraphInitParams
from graphon.enums import WorkflowType
from graphon.graph import Graph
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
),
run_context=run_context,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
if start_node_id is None:

View File

@@ -7,7 +7,7 @@ from typing import Union
from graphon.entities import WorkflowStartReason
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@@ -3,7 +3,6 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDictAdapter
from graphon.entities.pause_reason import HumanInputRequired
from graphon.graph import Graph
@@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.node_factory import (
DifyGraphInitContext,
DifyNodeFactory,
get_default_root_node_id,
resolve_workflow_node_class,
)
from core.workflow.system_variables import (
build_bootstrap_variables,
default_system_variables,
@@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=tenant_id or "",
app_id=self._app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow_id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=tenant_id or "",
app_id=self._app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
),
run_context=run_context,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
@@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -1,6 +1,6 @@
from graphon.model_runtime.entities.llm_entities import LLMUsage
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.entities.model_entities import ModelStatus
@@ -57,37 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
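
Note that the UPDATE above keeps Provider.quota_limit > Provider.quota_used inside the WHERE clause, so the check and the increment happen in a single statement and concurrent requests cannot push usage past the limit; once quota is exhausted, no row matches. A self-contained sketch of the pattern with a stand-in table (the rowcount handling is illustrative, not part of this diff):

from sqlalchemy import Integer, String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Quota(Base):  # stand-in for the Provider row
    __tablename__ = "quotas"
    tenant_id: Mapped[str] = mapped_column(String, primary_key=True)
    quota_used: Mapped[int] = mapped_column(Integer, default=0)
    quota_limit: Mapped[int] = mapped_column(Integer, default=10)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    session.add(Quota(tenant_id="t1", quota_used=9, quota_limit=10))

with sessionmaker(bind=engine).begin() as session:
    result = session.execute(
        update(Quota)
        .where(Quota.tenant_id == "t1", Quota.quota_limit > Quota.quota_used)
        .values(quota_used=Quota.quota_used + 1)
    )
    print(result.rowcount)  # 1 here; 0 once the limit is reached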

View File

@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
event = message.event
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
answer=output_moderation_answer
)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):

View File

@@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
match message_file.transfer_method:
case FileTransferMethod.REMOTE_URL:
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
case FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
case FileTransferMethod.TOOL_FILE if message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE:
pass
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
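
The guarded arm case FileTransferMethod.TOOL_FILE if message_file.url above relies on a match-statement rule worth spelling out: when a guard evaluates falsy, matching continues with the later cases, so a TOOL_FILE without a URL lands in the final TOOL_FILE | DATASOURCE_FILE arm. A small illustration with a stand-in enum:

from enum import Enum

class Method(Enum):
    TOOL_FILE = "tool_file"
    DATASOURCE_FILE = "datasource_file"

def describe(method: Method, url: str | None) -> str:
    match method:
        case Method.TOOL_FILE if url:  # taken only when url is truthy
            return f"tool file at {url}"
        case Method.TOOL_FILE | Method.DATASOURCE_FILE:
            return "no url to sign"  # TOOL_FILE without a url falls through to here
    return "unreachable for these members"

print(describe(Method.TOOL_FILE, None))  # -> "no url to sign"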

View File

@@ -187,15 +187,16 @@ def build_parameter_schema(
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
match app.mode:
case AppMode.WORKFLOW:
return {"inputs": arguments}
case AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
case _:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
def extract_answer_from_response(app: App, response: Any) -> str:
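
The chat-mode branch of prepare_tool_arguments copies the dict before popping "query", so the caller's arguments mapping is never mutated. A stand-alone mirror of that branch for illustration:

from typing import Any

def chat_mode_arguments(arguments: dict[str, Any]) -> dict[str, Any]:
    args_copy = arguments.copy()  # copy first so the caller's dict stays intact
    query = args_copy.pop("query", "")
    return {"query": query, "inputs": args_copy}

original = {"query": "hi", "topic": "redis"}
print(chat_mode_arguments(original))  # {'query': 'hi', 'inputs': {'topic': 'redis'}}
print(original)                       # unchanged
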
@@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
return response.get("answer", "")
case AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
case _:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(

View File

@@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
conversation_id = conversation_id or ""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT:
if not query:
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
elif app.mode == AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode == AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
raise ValueError("unexpected app type")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
case AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
case AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
case _:
raise ValueError("unexpected app type")
@classmethod
def invoke_chat_app(
@@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
match app.mode:
case AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
case AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case _:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
else:
raise ValueError("unexpected app type")
@classmethod
def invoke_workflow_app(
cls,

View File

@@ -961,36 +961,37 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
match provider_quota.quota_type:
case ProviderQuotaType.TRIAL if trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case ProviderQuotaType.PAID if paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case _:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
self.analyticdb_vector.add_texts(documents, embeddings)
return []
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)

View File

@@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
def create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException

View File

@@ -1,5 +1,6 @@
import json
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
@@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
)
@contextmanager
def _get_cursor(self):
def _get_cursor(self) -> Iterator[Any]:
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
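
The Iterator[Any] annotation on _get_cursor follows the usual typing convention for @contextmanager: the decorated function is written as a generator, so its return type describes what it yields, and the decorator turns it into a context manager over that value. A minimal self-contained example:

from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def managed_cursor() -> Iterator[str]:
    cursor = "cursor"  # stands in for pool.getconn() / conn.cursor()
    try:
        yield cursor
    finally:
        pass  # cleanup would go here, e.g. cur.close(); pool.putconn(conn)

with managed_cursor() as cur:
    print(cur)  # "cursor"
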
@@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
def create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):

View File

@@ -2,7 +2,7 @@ import json
from typing import Any, TypedDict
import chromadb
from chromadb import QueryResult, Settings
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
from pydantic import BaseModel
from configs import dify_config
@@ -106,14 +106,15 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter")
results: QueryResult
if document_ids_filter:
results: QueryResult = collection.query(
results = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
)
else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data
@@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
config=ChromaConfig(
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
),

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
@@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields: list[str] = []
class _Table(CollectionORM):
@@ -88,7 +87,7 @@
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
@@ -111,12 +110,11 @@
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
@@ -128,7 +126,6 @@
),
)
pks.append(pk)
session.commit()
return pks
@@ -145,10 +142,9 @@
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]):
with Session(self._client) as session:
@@ -159,15 +155,13 @@
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self):
with Session(self._client) as session:
with sessionmaker(bind=self._client).begin() as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:

View File

@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
import qdrant_client
from flask import current_app
@@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
@@ -180,7 +179,7 @@ class QdrantVector(BaseVector):
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
added_ids.extend(batch_ids)
return added_ids
@@ -472,7 +471,7 @@
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()
self._client._load() # pyright: ignore[reportPrivateUsage]
@classmethod
def _document_from_scored_point(

View File

@@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
Base: Any = declarative_base()
class RelytConfig(BaseModel):

View File

@@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
FileType,
detect_filetype,
)
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
# check the file extension
try:
import magic # noqa: F401
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:

View File

@@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import (
DatasetEntity,
@@ -884,7 +884,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer)
return
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Collect all document_ids and batch fetch DatasetDocuments
document_ids = {
doc.metadata["document_id"]
@@ -975,7 +975,6 @@ class DatasetRetrieval:
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
session.commit()
self._send_trace_task(message_id, documents, timer)

View File

@@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel):
parameter.pop("input_schema", None)
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
match self.type:
case ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
elif self.type == ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
optional_fields.update(
self.optional_field(
"authentication", self.authentication.model_dump() if self.authentication else None
)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
case ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
case _:
pass
return {
"id": self.id,
"author": self.author,

View File

@@ -11,6 +11,7 @@ from uuid import uuid4
import httpx
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@@ -166,13 +167,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
if not tool_file:
return None
@@ -190,13 +185,7 @@
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
MessageFile.id == id,
)
.first()
)
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
# Check if message_file is not None
if message_file is not None:
@@ -210,13 +199,7 @@
else:
tool_file_id = None
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None
@@ -234,13 +217,7 @@
:return: the binary of the file, mime type
"""
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
if not tool_file:
return None, None
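
These ToolFileManager changes replace the legacy session.query(...).first() spelling with the 2.0-style session.scalar(select(...).limit(1)), which likewise returns the first matching object or None. A self-contained sketch with a stand-in model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ToolFileStub(Base):  # minimal stand-in for the real ToolFile model
    __tablename__ = "tool_files"
    id: Mapped[str] = mapped_column(String, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(ToolFileStub(id="tf-1"))
    session.flush()
    # scalar() unwraps the single-entity row, or returns None when nothing matches.
    tool_file = session.scalar(select(ToolFileStub).where(ToolFileStub.id == "tf-1").limit(1))
    assert tool_file is not None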

View File

@@ -205,16 +205,160 @@ class ToolManager:
:return: the tool
"""
if provider_type == ToolProviderType.BUILT_IN:
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
match provider_type:
case ToolProviderType.BUILT_IN:
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
builtin_tool = provider_controller.get_tool(tool_name)
if not builtin_tool:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
builtin_tool = provider_controller.get_tool(tool_name)
if not builtin_tool:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
if is_valid_uuid(credential_id):
try:
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
if builtin_provider is None:
with Session(db.engine) as session:
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.limit(1)
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(decrypted_credentials),
credential_type=builtin_provider.credential_type,
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
case ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=api_provider,
)
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
case ToolProviderType.WORKFLOW:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
@@ -223,177 +367,28 @@ class ToolManager:
tool_invoke_from=tool_invoke_from,
)
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
# if the provider has been deleted, raise an error
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
# fallback to the default provider
if builtin_provider is None:
# use the default provider
with Session(db.engine) as session:
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.limit(1)
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# check if the credential is allowed to be used
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials are expired
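# (assumption, not stated in the diff: expires_at == -1 appears to act as a
# "never expires" sentinel, and the 60-second leeway refreshes the token
# slightly early so it cannot lapse mid-invocation)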
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(decrypted_credentials),
credential_type=builtin_provider.credential_type,
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=api_provider,
)
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
runtime = getattr(plugin_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return plugin_tool
elif provider_type == ToolProviderType.MCP:
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
runtime = getattr(mcp_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return mcp_tool
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
case ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
case ToolProviderType.PLUGIN:
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
runtime = getattr(plugin_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return plugin_tool
case ToolProviderType.MCP:
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
runtime = getattr(mcp_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return mcp_tool
case ToolProviderType.DATASET_RETRIEVAL:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
case _:
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
@classmethod
def get_agent_tool_runtime(
@ -1027,31 +1022,31 @@ class ToolManager:
:param provider_id: the id of the provider
:return:
"""
provider_type = provider_type
provider_id = provider_id
if provider_type == ToolProviderType.BUILT_IN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
match provider_type:
case ToolProviderType.BUILT_IN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
case ToolProviderType.API:
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
case ToolProviderType.WORKFLOW:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
case ToolProviderType.PLUGIN:
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
elif provider_type == ToolProviderType.API:
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.WORKFLOW:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.PLUGIN:
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
elif provider_type == ToolProviderType.MCP:
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
else:
raise ValueError(f"provider type {provider_type} not found")
case ToolProviderType.MCP:
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
raise ValueError(f"provider type {provider_type} not found")
case _:
raise ValueError(f"provider type {provider_type} not found")
@classmethod
def _convert_tool_parameters_type(

View File

@ -7,14 +7,13 @@ from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@ -1,11 +1,10 @@
from typing import NotRequired, TypedDict, cast
from typing import cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument
@ -17,18 +16,6 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,

View File

@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
# Match Unicode ranges for punctuation and symbols
# FIXME: this pattern is a confusing quick fix for #11868; maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
pattern = re.compile(
r"""
^
(?:
[\u2000-\u2025] # General Punctuation: spaces, quotes, dashes
| [\u2027-\u206F] # General Punctuation: ellipsis, underscores, etc.
| [\u2E00-\u2E7F] # Supplemental Punctuation: medieval, ancient marks
| [\u3000-\u300F] # CJK Punctuation: 、。〃「」『》』 (excludes 【】)
| [\u3012-\u303F] # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
| ["#$%&'()*+,./:;<=>?@^_`~] # ASCII punctuation (excludes []【】)
)+
""",
re.VERBOSE,
)
return re.sub(pattern, "", text)
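# Hypothetical spot-checks (not part of the commit), assuming the rewritten
# pattern above: ASCII punctuation and CJK brackets such as 「 are stripped,
# while [, ], and 【】 are now deliberately left alone.
assert remove_leading_symbols(",,hello") == "hello"
assert remove_leading_symbols("「hello」") == "hello」"
assert remove_leading_symbols("【hello】") == "【hello】"
assert remove_leading_symbols("[1] item") == "[1] item"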

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow | None = (
session.query(Workflow)
workflow: Workflow | None = session.scalar(
select(Workflow)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
.limit(1)
)
if not workflow:
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
with Session(db.engine, expire_on_commit=False) as session, session.begin():
db_provider: WorkflowToolProvider | None = (
session.query(WorkflowToolProvider)
db_provider: WorkflowToolProvider | None = session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == self.provider_id,
)
.first()
.limit(1)
)
if not db_provider:

View File

@ -305,14 +305,15 @@ class WorkflowTool(Tool):
"transfer_method": file.transfer_method.value,
"type": file.type.value,
}
if file.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
match file.transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
case FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
case FileTransferMethod.DATASOURCE_FILE:
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
case FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
files.append(file_dict)
except Exception:
@ -357,8 +358,11 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict):
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
if transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
elif transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_id
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
case FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_id
case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE:
pass
return file_dict

View File

@ -1,6 +1,7 @@
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast, final, override
@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
)
@dataclass(frozen=True, slots=True)
class DifyGraphInitContext:
"""Explicit graph-init values owned by the workflow layer.
Dify is gradually removing direct `GraphInitParams` construction from its
production call sites. Keep the translation here until `graphon` exposes an
equivalent explicit API.
"""
workflow_id: str
graph_config: Mapping[str, Any]
run_context: Mapping[str, Any]
call_depth: int
def to_graph_init_params(self) -> "GraphInitParams":
from graphon.entities import GraphInitParams
return GraphInitParams(
workflow_id=self.workflow_id,
graph_config=self.graph_config,
run_context=self.run_context,
call_depth=self.call_depth,
)
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
Default implementation of NodeFactory that resolves node classes from the live registry.
"""
@classmethod
def from_graph_init_context(
cls,
*,
graph_init_context: DifyGraphInitContext,
graph_runtime_state: "GraphRuntimeState",
) -> "DifyNodeFactory":
"""Bridge Dify's explicit init context into the current `graphon` API."""
return cls(
graph_init_params=graph_init_context.to_graph_init_params(),
graph_runtime_state=graph_runtime_state,
)
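# Simplified, self-contained sketch of the translation-layer shape introduced
# above (stand-in names, not the real DifyGraphInitContext/GraphInitParams):
# explicit frozen fields on the workflow side, converted to legacy kwargs only
# at the graphon boundary.
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True, slots=True)
class _InitContext:
    workflow_id: str
    graph_config: Mapping[str, Any]
    call_depth: int

    def to_params(self) -> dict[str, Any]:
        # the only place the legacy keyword shape is spelled out
        return {
            "workflow_id": self.workflow_id,
            "graph_config": self.graph_config,
            "call_depth": self.call_depth,
        }


_ctx = _InitContext(workflow_id="wf-1", graph_config={"nodes": []}, call_depth=0)
assert _ctx.to_params()["workflow_id"] == "wf-1"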
def __init__(
self,
graph_init_params: "GraphInitParams",

View File

@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]):
def post_init(self) -> None:
from core.workflow.node_runtime import DifyFileReferenceFactory
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = raw_data
continue
if param_type == SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
match param_type:
case SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):
file = files.get(param_name)
if file and isinstance(file, dict):
file_var = self.generate_file_var(param_name, file)
if file_var:
outputs[param_name] = file_var
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
case _:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data

View File

@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
from core.app.file_access import DatabaseFileAccessController
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
from core.workflow.node_factory import (
DifyGraphInitContext,
DifyNodeFactory,
is_start_node_type,
resolve_workflow_node_class,
)
from core.workflow.system_variables import (
default_system_variables,
get_node_creation_preload_selectors,
@ -251,17 +256,18 @@ class WorkflowEntry:
node_version = str(node_config_data.version)
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
# init graph init params and runtime state
graph_init_params = GraphInitParams(
# init graph context and runtime state
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
@ -313,8 +319,8 @@ class WorkflowEntry:
)
# init workflow run state
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
@ -409,17 +415,18 @@ class WorkflowEntry:
variable_pool = VariablePool()
add_variables_to_pool(variable_pool, default_system_variables())
# init graph init params and runtime state
graph_init_params = GraphInitParams(
# init graph context and runtime state
run_context = build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id="",
graph_config=graph_dict,
run_context=build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
@ -430,8 +437,8 @@ class WorkflowEntry:
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)

View File

@ -68,46 +68,49 @@ class EnterpriseMetricHandler:
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
match case:
case TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
case TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
case TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
case TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
case TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
case TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
case TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
case TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
case TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
case TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
case TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION:
pass
case _:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
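# Design note with an illustrative alternative (not what the commit does):
# every arm above follows the same "call handler, bump counter with the case
# label" shape, so the dispatch could also be table-driven. The match form has
# the advantage that the explicit pass-through arm for WORKFLOW_RUN and its
# siblings stays visible at the dispatch site. Stand-in names below.
from collections.abc import Callable
from enum import StrEnum


class _Case(StrEnum):
    APP_CREATED = "app_created"
    APP_UPDATED = "app_updated"


def _on_app_created(envelope: dict) -> None:
    print("created", envelope["event_id"])


def _on_app_updated(envelope: dict) -> None:
    print("updated", envelope["event_id"])


_HANDLERS: dict[_Case, Callable[[dict], None]] = {
    _Case.APP_CREATED: _on_app_created,
    _Case.APP_UPDATED: _on_app_updated,
}


def _handle(case: _Case, envelope: dict) -> None:
    handler = _HANDLERS.get(case)
    if handler is None:
        return  # unknown case: log-and-continue, like the warning branch above
    handler(envelope)
    # the counter label falls out of case.value, deduplicating the literals


_handle(_Case.APP_CREATED, {"event_id": "evt-1"})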
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.

View File

@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
from redis.backoff import ExponentialWithJitterBackoff # type: ignore
from redis.cache import CacheConfig
from redis.client import PubSub
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from configs import dify_config
@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None:
return CacheConfig()
def _get_retry_policy() -> Retry:
"""Build the shared retry policy for Redis connections."""
return Retry(
backoff=ExponentialWithJitterBackoff(
base=dify_config.REDIS_RETRY_BACKOFF_BASE,
cap=dify_config.REDIS_RETRY_BACKOFF_CAP,
),
retries=dify_config.REDIS_RETRY_RETRIES,
)
def _get_connection_health_params() -> dict[str, Any]:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
def _get_cluster_connection_health_params() -> dict[str, Any]:
"""Get retry and timeout parameters for Redis Cluster clients.
RedisCluster does not support ``health_check_interval`` as a constructor
keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
return {k: v for k, v in params.items() if k != "health_check_interval"}
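# Illustrative math only (a common full-jitter scheme); redis-py's
# ExponentialWithJitterBackoff may differ in detail. With the new defaults
# (base=1.0, cap=10.0, retries=3) the per-attempt delay ceiling grows
# 1s -> 2s -> 4s and would be clipped at the 10s cap on later attempts.
import random


def _jittered_delay(attempt: int, base: float = 1.0, cap: float = 10.0) -> float:
    ceiling = min(cap, base * (2**attempt))
    return random.uniform(0, ceiling)


for _attempt in range(3):  # REDIS_RETRY_RETRIES defaults to 3
    _d = _jittered_delay(_attempt)
    assert 0 <= _d <= min(10.0, 1.0 * (2**_attempt))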
def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters."""
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]:
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_connection_health_params(),
}
@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_cluster_connection_health_params(),
}
if dify_config.REDIS_MAX_CONNECTIONS:
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
redis_params.update(
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
)
if dify_config.REDIS_MAX_CONNECTIONS:
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
if ssl_kwargs:
redis_params.update(ssl_kwargs)
params.update(ssl_kwargs)
pool = redis.ConnectionPool(**redis_params)
pool = redis.ConnectionPool(**params)
client: redis.Redis = redis.Redis(connection_pool=pool)
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
max_conns = dify_config.REDIS_MAX_CONNECTIONS
if use_clusters:
if max_conns:
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
else:
return RedisCluster.from_url(pubsub_url)
if use_clusters:
health_params = _get_cluster_connection_health_params()
kwargs: dict[str, Any] = {**health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
if max_conns:
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
else:
return redis.Redis.from_url(pubsub_url)
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
def init_app(app: DifyApp):

View File

@ -813,56 +813,32 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
@staticmethod
def _dump_optional(value: Any) -> str | None:
return json.dumps(value) if value else None
def from_model_config_dict(self, model_config: AppModelConfigDict):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config.get("suggested_questions_after_answer"))
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.sensitive_word_avoidance = (
json.dumps(model_config.get("sensitive_word_avoidance"))
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
self.suggested_questions = self._dump_optional(model_config.get("suggested_questions"))
self.suggested_questions_after_answer = self._dump_optional(
model_config.get("suggested_questions_after_answer")
)
self.speech_to_text = self._dump_optional(model_config.get("speech_to_text"))
self.text_to_speech = self._dump_optional(model_config.get("text_to_speech"))
self.more_like_this = self._dump_optional(model_config.get("more_like_this"))
self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance"))
self.external_data_tools = self._dump_optional(model_config.get("external_data_tools"))
self.model = self._dump_optional(model_config.get("model"))
self.user_input_form = self._dump_optional(model_config.get("user_input_form"))
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.agent_mode = self._dump_optional(model_config.get("agent_mode"))
self.retriever_resource = self._dump_optional(model_config.get("retriever_resource"))
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
self.chat_prompt_config = (
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
)
self.completion_prompt_config = (
json.dumps(model_config.get("completion_prompt_config"))
if model_config.get("completion_prompt_config")
else None
)
self.dataset_configs = (
json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None
)
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config"))
self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config"))
self.dataset_configs = self._dump_optional(model_config.get("dataset_configs"))
self.file_upload = self._dump_optional(model_config.get("file_upload"))
return self
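# Behavior check for the helper above (standalone illustrative copy): it keeps
# the exact truthiness semantics of the inline expressions it replaces, so any
# falsy value, including an explicitly empty dict, is stored as NULL rather
# than as serialized JSON.
import json
from typing import Any


def _dump_optional(value: Any) -> str | None:
    return json.dumps(value) if value else None


assert _dump_optional({"enabled": True}) == '{"enabled": true}'
assert _dump_optional({}) is None     # empty dict collapses to NULL
assert _dump_optional(None) is None
assert _dump_optional(False) is None  # falsy scalars collapse too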
@ -1632,52 +1608,53 @@ class Message(Base):
files: list[File] = []
for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if message_file.upload_file_id is None:
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
file = file_factory.build_from_mapping(
mapping={
match message_file.transfer_method:
case FileTransferMethod.LOCAL_FILE:
if message_file.upload_file_id is None:
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
file = file_factory.build_from_mapping(
mapping={
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"upload_file_id": message_file.upload_file_id,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
case FileTransferMethod.REMOTE_URL:
if message_file.url is None:
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
file = file_factory.build_from_mapping(
mapping={
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"upload_file_id": message_file.upload_file_id,
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
case FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
assert message_file.url is not None
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]
mapping = {
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"upload_file_id": message_file.upload_file_id,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
file = file_factory.build_from_mapping(
mapping={
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"upload_file_id": message_file.upload_file_id,
"url": message_file.url,
},
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
assert message_file.url is not None
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]
mapping = {
"id": message_file.id,
"type": message_file.type,
"transfer_method": message_file.transfer_method,
"tool_file_id": message_file.upload_file_id,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
else:
raise ValueError(
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
)
"tool_file_id": message_file.upload_file_id,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=current_app.tenant_id,
access_controller=_get_file_access_controller(),
)
case FileTransferMethod.DATASOURCE_FILE:
raise ValueError(
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
)
files.append(file)
result = cast(

View File

@ -1625,21 +1625,22 @@ class WorkflowDraftVariable(Base):
# Rebuild them through the file factory so tenant ownership, signed URLs,
# and storage-backed metadata come from canonical records instead of the
# serialized JSON blob.
if segment_type == SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = self._rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
if segment_type == SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
file_list = self._rebuild_file_types(value)
return build_segment_with_type(segment_type=segment_type, value=file_list)
return build_segment_with_type(segment_type=segment_type, value=value)
match segment_type:
case SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = self._rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
case SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
file_list = self._rebuild_file_types(value)
return build_segment_with_type(segment_type=segment_type, value=file_list)
case _:
return build_segment_with_type(segment_type=segment_type, value=value)
@staticmethod
def rebuild_file_types(value: Any):
@ -1672,21 +1673,22 @@ class WorkflowDraftVariable(Base):
# Extends `variable_factory.build_segment_with_type` functionality by
# reconstructing `FileSegment` or `ArrayFileSegment` objects from
# their serialized dictionary or list representations, respectively.
if segment_type == SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
if segment_type == SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
file_list = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type=segment_type, value=file_list)
return build_segment_with_type(segment_type=segment_type, value=value)
match segment_type:
case SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
case SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
file_list = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type=segment_type, value=file_list)
case _:
return build_segment_with_type(segment_type=segment_type, value=value)
def get_value(self) -> Segment:
"""Decode the serialized value into its corresponding `Segment` object.

View File

@ -40,7 +40,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",
@ -171,7 +171,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.59.1",
"pyrefly>=0.60.0",
]
############################################################

View File

@ -4,6 +4,7 @@ import time
from collections.abc import Sequence
import click
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, sessionmaker
import app
@ -113,11 +114,9 @@ def _delete_batch(
try:
with session.begin_nested():
workflow_run_ids = [run.id for run in workflow_runs]
message_data = (
session.query(Message.id, Message.conversation_id)
.where(Message.workflow_run_id.in_(workflow_run_ids))
.all()
)
message_data = session.execute(
select(Message.id, Message.conversation_id).where(Message.workflow_run_id.in_(workflow_run_ids))
).all()
message_id_list = [msg.id for msg in message_data]
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
if message_id_list:
@ -132,23 +131,19 @@ def _delete_batch(
SavedMessage,
]
for model in message_related_models:
session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore
session.execute(delete(model).where(model.message_id.in_(message_id_list))) # type: ignore
# error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib
# and these 6 types all have the message_id field.
session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
synchronize_session=False
)
session.execute(delete(Message).where(Message.workflow_run_id.in_(workflow_run_ids)))
if conversation_id_list:
session.query(ConversationVariable).where(
ConversationVariable.conversation_id.in_(conversation_id_list)
).delete(synchronize_session=False)
session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
synchronize_session=False
session.execute(
delete(ConversationVariable).where(ConversationVariable.conversation_id.in_(conversation_id_list))
)
session.execute(delete(Conversation).where(Conversation.id.in_(conversation_id_list)))
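# Self-contained sketch of the 2.0-style bulk delete adopted above, against a
# throwaway SQLite model (illustrative, not Dify's schema). It is the
# statement-API equivalent of the legacy
# Query(...).delete(synchronize_session=False) calls it replaces.
from sqlalchemy import Integer, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class _Base(DeclarativeBase):
    pass


class _Msg(_Base):
    __tablename__ = "msg"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)


_engine = create_engine("sqlite://")
_Base.metadata.create_all(_engine)

with Session(_engine) as _session, _session.begin():
    _session.add_all([_Msg(id=1), _Msg(id=2), _Msg(id=3)])

with Session(_engine) as _session, _session.begin():
    # one DELETE ... WHERE id IN (...) statement, no rows loaded into memory
    _session.execute(delete(_Msg).where(_Msg.id.in_([1, 2])))

with Session(_engine) as _session:
    assert _session.scalars(select(_Msg.id)).all() == [3]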
def _delete_node_executions(active_session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(

View File

@ -3,7 +3,6 @@ import hashlib
import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat
from graphon.nodes.tool.entities import ToolNodeData
from packaging import version
from packaging.version import parse as parse_version
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.workflow import Workflow
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.6.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class Import(BaseModel):
id: str
status: ImportStatus
@ -75,10 +63,6 @@ class Import(BaseModel):
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:

View File

@ -120,7 +120,7 @@ class ClearFreePlanTenantExpiredLogs:
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
messages = (
session.query(Message)
.where(
@ -152,7 +152,6 @@ class ClearFreePlanTenantExpiredLogs:
).delete(synchronize_session=False)
cls._clear_message_related_tables(session, tenant_id, message_ids)
session.commit()
click.echo(
click.style(
@ -161,7 +160,7 @@ class ClearFreePlanTenantExpiredLogs:
)
while True:
with Session(db.engine).no_autoflush as session:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
conversations = (
session.query(Conversation)
.where(
@ -190,7 +189,6 @@ class ClearFreePlanTenantExpiredLogs:
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.commit()
click.echo(
click.style(
@ -294,7 +292,7 @@ class ClearFreePlanTenantExpiredLogs:
break
while True:
with Session(db.engine).no_autoflush as session:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
.where(
@ -326,7 +324,6 @@ class ClearFreePlanTenantExpiredLogs:
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
)
session.commit()
click.echo(
click.style(

View File

@ -5,7 +5,7 @@ from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@ -53,13 +53,12 @@ class DatasourceProviderService:
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def decrypt_datasource_provider_credentials(
self,
@ -109,7 +108,7 @@ class DatasourceProviderService:
"""
get credential by id
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
@ -156,7 +155,6 @@ class DatasourceProviderService:
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
@ -174,7 +172,7 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
@ -224,7 +222,6 @@ class DatasourceProviderService:
provider=provider,
)
real_credentials_list.append(real_credentials)
session.commit()
return real_credentials_list
@ -234,7 +231,7 @@ class DatasourceProviderService:
"""
update datasource provider name
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
@ -266,7 +263,6 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists")
target_provider.name = name
session.commit()
return
def set_default_datasource_provider(
@ -275,7 +271,7 @@ class DatasourceProviderService:
"""
set default datasource provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
@ -300,7 +296,6 @@ class DatasourceProviderService:
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
def setup_oauth_custom_client_params(
@ -315,7 +310,7 @@ class DatasourceProviderService:
"""
if client_params is None and enabled is None:
return
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
@ -349,7 +344,6 @@ class DatasourceProviderService:
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
@ -488,7 +482,7 @@ class DatasourceProviderService:
"""
update datasource oauth provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
@ -535,7 +529,6 @@ class DatasourceProviderService:
target_provider.expires_at = expire_at
target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()
def add_datasource_oauth_provider(
self,
@ -550,7 +543,7 @@ class DatasourceProviderService:
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=60):
db_provider_name = name
@ -604,7 +597,6 @@ class DatasourceProviderService:
expires_at=expire_at,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
@ -623,7 +615,7 @@ class DatasourceProviderService:
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
@ -670,7 +662,6 @@ class DatasourceProviderService:
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
"""
@ -926,7 +917,7 @@ class DatasourceProviderService:
update datasource credentials.
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
@ -980,7 +971,6 @@ class DatasourceProviderService:
encrypted_credentials[key] = value
datasource_provider.encrypted_credentials = encrypted_credentials
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Mapping
from sqlalchemy import case, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
@ -24,7 +24,7 @@ class EndUserService:
when an end-user ID is known.
"""
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
return session.scalar(
select(EndUser)
.where(
@ -54,7 +54,7 @@ class EndUserService:
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
# This single query approach is more efficient than separate queries
end_user = session.scalar(
@ -82,7 +82,6 @@ class EndUserService:
user_id,
)
end_user.type = type
session.commit()
else:
# Create new end user if none exists
end_user = EndUser(
@ -94,7 +93,6 @@ class EndUserService:
external_user_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
@ -135,7 +133,7 @@ class EndUserService:
if not unique_app_ids:
return result
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = list(
session.scalars(
@ -174,7 +172,6 @@ class EndUserService:
)
session.add_all(new_end_users)
session.commit()
for eu in new_end_users:
result[eu.app_id] = eu
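# Standalone demonstration (illustrative model) of the session idiom adopted
# across these services: sessionmaker(...).begin() opens a transaction that
# commits on normal exit and rolls back on exception, which is why the
# explicit session.commit() calls disappear throughout this commit.
from sqlalchemy import Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class _Base(DeclarativeBase):
    pass


class _Item(_Base):
    __tablename__ = "item"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)


_engine = create_engine("sqlite://")
_Base.metadata.create_all(_engine)

with sessionmaker(bind=_engine).begin() as _session:
    _session.add(_Item(id=1))  # committed automatically on exit

with sessionmaker(bind=_engine)() as _session:
    assert _session.scalar(select(_Item.id)) == 1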

View File

@ -0,0 +1,21 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.plugin.entities.plugin import PluginDependency
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
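# Hypothetical usage check for the new shared module (the import path is from
# this commit; the literal values below are just illustrations). Because the
# enums derive from StrEnum, members compare equal to their string values, so
# existing JSON payloads like {"status": "completed"} keep working unchanged.
from services.entities.dsl_entities import ImportMode, ImportStatus

assert ImportMode("yaml-content") is ImportMode.YAML_CONTENT
assert ImportStatus.COMPLETED == "completed"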

View File

@ -1,4 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
@ -7,11 +8,11 @@ from models.account import TenantPluginAutoUpgradeStrategy
class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with Session(db.engine) as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
with sessionmaker(bind=db.engine).begin() as session:
return session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
@staticmethod
@ -23,11 +24,11 @@ class PluginAutoUpgradeService:
exclude_plugins: list[str],
include_plugins: list[str],
) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
if not exist_strategy:
strategy = TenantPluginAutoUpgradeStrategy(
@ -46,16 +47,15 @@ class PluginAutoUpgradeService:
exist_strategy.exclude_plugins = exclude_plugins
exist_strategy.include_plugins = include_plugins
session.commit()
return True
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
.limit(1)
)
if not exist_strategy:
# create for this tenant
@ -83,5 +83,4 @@ class PluginAutoUpgradeService:
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exist_strategy.exclude_plugins = [plugin_id]
session.commit()
return True

View File

@ -1,4 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.account import TenantPluginPermission
@ -7,8 +8,10 @@ from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
with sessionmaker(bind=db.engine).begin() as session:
return session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
@staticmethod
def change_permission(
@ -16,9 +19,9 @@ class PluginPermissionService:
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
with sessionmaker(bind=db.engine).begin() as session:
permission = session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
if not permission:
permission = TenantPluginPermission(
@ -30,5 +33,4 @@ class PluginPermissionService:
permission.install_permission = install_permission
permission.debug_permission = debug_permission
session.commit()
return True
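# Compile-only check (illustrative model) that the select(...).limit(1) idiom
# adopted above actually emits a LIMIT clause, so session.scalar() fetches at
# most one row, matching the legacy Query(...).first() it replaces.
from sqlalchemy import Integer, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class _Base(DeclarativeBase):
    pass


class _Strategy(_Base):
    __tablename__ = "strategy"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    tenant_id: Mapped[int] = mapped_column(Integer)


_stmt = select(_Strategy).where(_Strategy.tenant_id == 7).limit(1)
assert "LIMIT" in str(_stmt)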

View File

@ -555,7 +555,7 @@ class RagPipelineService:
workflow_node_execution.id
)
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=pipeline.id,
@@ -569,7 +569,6 @@ class RagPipelineService:
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
)
session.commit()
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
@@ -1325,7 +1324,7 @@ class RagPipelineService:
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=pipeline.id,
@@ -1339,7 +1338,6 @@ class RagPipelineService:
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,
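Both hunks in this file replace `with Session(bind=db.engine) as session, session.begin():` with the single `sessionmaker(bind=db.engine).begin()` context manager and drop the now-redundant `session.commit()`. The two spellings are equivalent in SQLAlchemy 1.4/2.0; a compressed sketch of the semantics:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")

# Old spelling: two stacked context managers; session.begin() commits on
# clean exit and rolls back if the block raises.
with Session(bind=engine) as session, session.begin():
    pass  # work happens here

# New spelling: one context manager with the same commit/rollback contract,
# which is why the explicit session.commit() calls could be deleted.
with sessionmaker(bind=engine).begin() as session:
    pass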

View File

@@ -5,7 +5,6 @@ import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@@ -21,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.nodes.tool.entities import ToolNodeData
from packaging import version
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -38,6 +37,7 @@ from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.enums import CollectionBindingType, DatasetRuntimeMode
from models.workflow import Workflow, WorkflowType
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
KnowledgeConfiguration,
@@ -54,18 +54,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
@@ -76,10 +64,6 @@ class RagPipelineImportInfo(BaseModel):
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
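This hunk deletes the module-local `ImportMode`, `ImportStatus`, and `CheckDependenciesResult` in favor of the shared definitions in `services.entities.dsl_entities` (imported a few lines up). Beyond deduplication, parallel `StrEnum` definitions are a subtle hazard: their members compare equal by string value while remaining unrelated types. A small sketch of why centralizing them matters, with hypothetical names:

from enum import StrEnum

# Imagine these live in two different modules, as before this change.
class ImportStatusA(StrEnum):
    COMPLETED = "completed"

class ImportStatusB(StrEnum):
    COMPLETED = "completed"

# Members are strings, so cross-module value comparison quietly succeeds...
assert ImportStatusA.COMPLETED == ImportStatusB.COMPLETED

# ...but the types stay unrelated, so isinstance checks and annotations
# written against one enum reject members of the other.
assert not isinstance(ImportStatusA.COMPLETED, ImportStatusB)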

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypedDict, cast
import sqlalchemy as sa
from sqlalchemy import delete, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from extensions.ext_database import db
@@ -369,7 +369,7 @@ class MessagesCleanService:
batch_deleted_messages = 0
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
fetch_messages_start = time.monotonic()
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
@@ -477,7 +477,7 @@
# Step 4: Batch delete messages and their relations
if not self._dry_run:
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
delete_relations_start = time.monotonic()
# Delete related records first
self._batch_delete_message_relations(session, message_ids_to_delete)
@@ -489,9 +489,7 @@
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
delete_messages_ms = int((time.monotonic() - delete_messages_start) * 1000)
commit_start = time.monotonic()
session.commit()
commit_ms = int((time.monotonic() - commit_start) * 1000)
commit_ms = 0
stats["total_deleted"] += messages_deleted
batch_deleted_messages = messages_deleted
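Since `begin()` now performs the commit as the block closes, there is no longer a discrete statement to wrap in a timer, and the hunk pins `commit_ms` to 0 to keep the stats payload shape unchanged. If commit latency were still of interest, the whole block could be timed instead; a rough sketch, assuming a `stats` dict like the one used above:

import time

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")
stats: dict[str, int] = {}

block_start = time.monotonic()
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
    ...  # deletes execute here; the commit runs as the block exits
# Commit cost is folded into the block total rather than reported separately.
stats["delete_block_ms"] = int((time.monotonic() - block_start) * 1000)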

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@@ -46,13 +46,12 @@ class BuiltinToolManageService:
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
@@ -150,7 +149,7 @@ class BuiltinToolManageService:
"""
update builtin tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
@@ -203,9 +202,7 @@
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@@ -222,7 +219,7 @@
"""
add builtin tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
@@ -281,9 +278,7 @@
)
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@@ -379,7 +374,7 @@
"""
delete tool provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
db_provider = (
session.query(BuiltinToolProvider)
.where(
@@ -393,7 +388,6 @@
raise ValueError(f"you have not added provider {provider}")
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@@ -409,7 +403,7 @@
"""
set default provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None:
@@ -422,7 +416,6 @@
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@@ -654,7 +647,7 @@
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
@@ -690,7 +683,6 @@
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
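The deleted `session.rollback()` calls in this file were defensive cleanup for manually managed sessions. Under `sessionmaker(...).begin()` they are dead code: when an exception escapes the block (including the `ValueError` re-raised from the handlers above), the context manager rolls back automatically. A minimal sketch of that guarantee:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")

try:
    with sessionmaker(bind=engine).begin() as session:
        session.execute(text("SELECT 1"))
        raise ValueError("boom")  # begin() rolls back, then re-raises
except ValueError:
    pass  # nothing was committed; no manual rollback required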

View File

@@ -48,21 +48,25 @@ class ToolTransformService:
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
)
if provider_type == ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
return ""
match provider_type:
case ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
case ToolProviderType.API | ToolProviderType.WORKFLOW:
try:
if isinstance(icon, str):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
case ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
return ""
case _:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
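The rewrite above swaps an `if`/`elif` chain for a `match` over `ToolProviderType`, naming every member so no case falls through silently; the wildcard arm stays only as a belt-and-braces default. A condensed sketch of the shape, with a hypothetical three-member enum:

from enum import StrEnum

class ProviderKind(StrEnum):  # hypothetical stand-in for ToolProviderType
    BUILT_IN = "built-in"
    API = "api"
    MCP = "mcp"

def icon_url(kind: ProviderKind, name: str) -> str:
    match kind:
        case ProviderKind.BUILT_IN:
            return f"/console/api/tool-provider/builtin/{name}/icon"
        case ProviderKind.API | ProviderKind.MCP:  # alternatives share an arm
            return ""
        case _:  # unreachable today; guards members added later
            return ""

Spelling out the remaining members, as the hunk does with `PLUGIN | APP | DATASET_RETRIEVAL`, documents that they were considered deliberately rather than swallowed by the wildcard.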

View File

@@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic.
import logging
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.enums import AppTriggerStatus
@@ -34,13 +34,12 @@ class AppTriggerService:
"""
try:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.execute(
update(AppTrigger)
.where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED)
.values(status=AppTriggerStatus.RATE_LIMITED)
)
session.commit()
logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id)
except Exception:
logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id)
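The method keeps its Core `update()` statement but now runs it inside a `begin()` block, so the UPDATE and its commit succeed or fail as a unit. A minimal sketch of the bulk-update pattern, with a hypothetical `Trigger` model standing in for `AppTrigger`:

from sqlalchemy import String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Trigger(Base):  # hypothetical stand-in for AppTrigger
    __tablename__ = "triggers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    status: Mapped[str] = mapped_column(String(16), default="enabled")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def rate_limit_tenant(tenant_id: str) -> None:
    with sessionmaker(bind=engine).begin() as session:
        # A single UPDATE flips every matching row; committed at block exit.
        session.execute(
            update(Trigger)
            .where(Trigger.tenant_id == tenant_id, Trigger.status == "enabled")
            .values(status="rate_limited")
        )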

View File

@@ -6,7 +6,7 @@ from collections.abc import Mapping
from typing import Any
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@@ -146,7 +146,7 @@ class TriggerProviderService:
"""
try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
with redis_client.lock(lock_key, timeout=20):
@@ -205,7 +204,6 @@
subscription.id = subscription_id or str(uuid.uuid4())
session.add(subscription)
session.commit()
return {
"result": "success",
@@ -241,7 +240,7 @@
:param expires_at: Optional new expiration timestamp
:return: Success response with updated subscription info
"""
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
@@ -302,8 +301,6 @@
if expires_at is not None:
subscription.expires_at = expires_at
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
@@ -404,7 +401,7 @@
:param subscription_id: Subscription instance ID
:return: New token info
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not subscription:
@@ -448,7 +445,6 @@
# Update credentials
subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
subscription.credential_expires_at = refreshed_credentials.expires_at
session.commit()
# Clear cache
cache.delete()
@@ -478,7 +474,7 @@
"""
now_ts: int = int(now if now is not None else _time.time())
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
@@ -531,7 +527,6 @@
# Persist refreshed properties and expires_at
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
subscription.expires_at = int(refreshed.expires_at)
session.commit()
properties_cache.delete()
logger.info(
@@ -639,7 +634,7 @@
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
@@ -683,8 +678,6 @@
if enabled is not None:
custom_client.enabled = enabled
session.commit()
return {"result": "success"}
@classmethod
@@ -733,13 +726,12 @@
:param provider_id: Provider identifier
:return: Success response
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.commit()
return {"result": "success"}
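Several methods in this file nest a Redis distributed lock inside the transactional block. Worth flagging: with the explicit `session.commit()` removed, the commit now happens when the session block exits, which is after the inner lock has been released. If the lock is meant to cover the write itself, inverting the nesting keeps the commit under the lock; a hedged sketch using redis-py's `Redis.lock`, assuming a reachable Redis server:

import redis
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")
r = redis.Redis()

def guarded_write(key: str) -> None:
    # Lock outside, transaction inside: the commit performed at the inner
    # block's exit still runs while the lock is held.
    with r.lock(f"write_lock:{key}", timeout=20):
        with sessionmaker(bind=engine).begin() as session:
            ...  # read-check-write against the database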

View File

@@ -8,7 +8,7 @@ from flask import Request, Response
from graphon.entities.graph_config import NodeConfigDict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
@@ -215,7 +215,7 @@ class TriggerService:
not_found_in_cache.append(node_info)
continue
with Session(db.engine) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
try:
# lock the concurrent plugin trigger creation
redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
@@ -260,7 +260,6 @@
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# Update existing records if subscription_id changed
for node_info in nodes_in_graph:
@@ -290,14 +289,12 @@
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
raise

View File

@@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod
from graphon.variables.types import ArrayValidation, SegmentType
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
@@ -598,21 +598,38 @@ class WebhookService:
Raises:
ValueError: If the value cannot be converted to the specified type
"""
if param_type == SegmentType.STRING:
return value
elif param_type == SegmentType.NUMBER:
if not cls._can_convert_to_number(value):
raise ValueError(f"Cannot convert '{value}' to number")
numeric_value = float(value)
return int(numeric_value) if numeric_value.is_integer() else numeric_value
elif param_type == SegmentType.BOOLEAN:
lower_value = value.lower()
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
if lower_value not in bool_map:
raise ValueError(f"Cannot convert '{value}' to boolean")
return bool_map[lower_value]
else:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
match param_type:
case SegmentType.STRING:
return value
case SegmentType.NUMBER:
if not cls._can_convert_to_number(value):
raise ValueError(f"Cannot convert '{value}' to number")
numeric_value = float(value)
return int(numeric_value) if numeric_value.is_integer() else numeric_value
case SegmentType.BOOLEAN:
lower_value = value.lower()
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
if lower_value not in bool_map:
raise ValueError(f"Cannot convert '{value}' to boolean")
return bool_map[lower_value]
case (
SegmentType.OBJECT
| SegmentType.FILE
| SegmentType.ARRAY_ANY
| SegmentType.ARRAY_STRING
| SegmentType.ARRAY_NUMBER
| SegmentType.ARRAY_OBJECT
| SegmentType.ARRAY_FILE
| SegmentType.ARRAY_BOOLEAN
| SegmentType.SECRET
| SegmentType.INTEGER
| SegmentType.FLOAT
| SegmentType.NONE
| SegmentType.GROUP
):
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
case _:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
@classmethod
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
@@ -918,7 +935,7 @@ class WebhookService:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -947,14 +964,12 @@ class WebhookService:
redis_client.set(
f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
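The `_convert_form_value` rewrite above enumerates every `SegmentType` member so unsupported kinds raise explicitly instead of falling into a catch-all `else`. The coercion rules themselves are unchanged; a self-contained sketch of the same logic, simplified to plain strings in place of the `SegmentType` enum and without the numeric pre-check helper:

def convert_form_value(name: str, value: str, kind: str):
    match kind:
        case "string":
            return value
        case "number":
            numeric = float(value)  # raises ValueError on bad input
            return int(numeric) if numeric.is_integer() else numeric
        case "boolean":
            mapping = {"true": True, "false": False, "1": True,
                       "0": False, "yes": True, "no": False}
            try:
                return mapping[value.lower()]
            except KeyError:
                raise ValueError(f"Cannot convert '{value}' to boolean") from None
        case _:
            raise ValueError(f"Unsupported type '{kind}' for form data parameter '{name}'")

assert convert_form_value("count", "3.0", "number") == 3
assert convert_form_value("flag", "Yes", "boolean") is True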

View File

@@ -1075,9 +1075,8 @@ class DraftVariableSaver:
)
engine = bind = self._session.get_bind()
assert isinstance(engine, Engine)
with Session(bind=engine, expire_on_commit=False) as session:
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
session.add(variable_file)
session.commit()
return truncation_result.result, variable_file

View File

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast
from graphon.entities import GraphInitParams, WorkflowNodeExecution
from graphon.entities import WorkflowNodeExecution
from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import (
@@ -48,7 +48,12 @@ from core.workflow.human_input_compat import (
normalize_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type
from core.workflow.node_factory import (
LATEST_VERSION,
DifyGraphInitContext,
get_node_type_classes_mapping,
is_start_node_type,
)
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
@@ -837,7 +842,7 @@ class WorkflowService:
with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
@ -848,7 +853,6 @@ class WorkflowService:
user=account,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
@@ -977,7 +981,7 @@ class WorkflowService:
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
with Session(bind=db.engine) as session, session.begin():
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
@ -988,7 +992,6 @@ class WorkflowService:
enclosing_node_id=enclosing_node_id,
)
draft_var_saver.save(outputs=outputs, process_data={})
session.commit()
return outputs
@@ -1134,18 +1137,20 @@ class WorkflowService:
node_config: NodeConfigDict,
variable_pool: VariablePool,
) -> HumanInputNode:
graph_init_params = GraphInitParams(
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_init_params = graph_init_context.to_graph_init_params()
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
@ -1155,7 +1160,7 @@ class WorkflowService:
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
runtime=DifyHumanInputNodeRuntime(run_context),
)
return node

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete, select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
dataset_document = session.scalar(
select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1)
)
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
@@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str):
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == SegmentStatus.COMPLETED,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
documents = []
multimodal_documents = []
@@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
session.query(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
session.execute(
delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id)
)
# update segment to enable
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
session.execute(
update(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id)
.values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now())
)
session.commit()
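This task shows the read-side half of the migration: `session.query(...).all()` becomes `session.scalars(select(...)).all()`, while bulk mutations move to Core `update()`/`delete()` statements executed on the session. A compact, runnable sketch of all three shapes with a hypothetical `Segment` model:

from sqlalchemy import String, create_engine, delete, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Segment(Base):  # hypothetical stand-in for DocumentSegment
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str] = mapped_column(String(36))
    enabled: Mapped[bool] = mapped_column(default=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    # ORM rows: scalars() unwraps single-entity result rows.
    segments = session.scalars(
        select(Segment).where(Segment.document_id == "doc-1").order_by(Segment.id)
    ).all()
    # Bulk UPDATE/DELETE as Core statements, with no per-object loading.
    session.execute(
        update(Segment).where(Segment.document_id == "doc-1").values(enabled=True)
    )
    session.execute(delete(Segment).where(Segment.enabled.is_(False)))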

View File

@@ -1,9 +1,11 @@
import logging
import time
from typing import cast
import click
from celery import shared_task
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = int(
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
result = cast(
CursorResult,
session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
),
)
deleted_count = result.rowcount
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
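Core `delete()` returns a `CursorResult`, so the deleted-row count that `Query.delete()` used to return directly now comes from `rowcount`; the `cast` in the hunk only satisfies the type checker, since `Session.execute` is annotated as returning a generic `Result`. A minimal sketch:

from typing import cast

from sqlalchemy import create_engine, delete
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Binding(Base):  # hypothetical stand-in for DatasetMetadataBinding
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    result = cast(CursorResult, session.execute(delete(Binding)))
    print(result.rowcount)  # rows removed by the DELETE (0 on this empty table)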

View File

@@ -112,7 +112,9 @@ def clean_dataset_task(
segment_ids = [segment.id for segment in segments]
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
@@ -150,20 +152,22 @@
)
session.execute(binding_delete_stmt)
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
session.execute(
delete(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
)
)
# delete files
if documents:
file_ids = []
@@ -174,7 +178,7 @@ def clean_dataset_task(
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
storage.delete(file.key)

View File

@@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Document has no dataset")
@@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
@@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if file:
try:
storage.delete(file.key)
@@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
)
)
end_at = time.perf_counter()
logger.info(

View File

@@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@@ -82,13 +83,17 @@
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
@@ -104,8 +109,10 @@
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@@ -115,15 +122,14 @@
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@@ -172,13 +178,17 @@
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else:

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import ConversationVariable
@@ -29,29 +30,21 @@ def delete_conversation_related_data(conversation_id: str):
with session_factory.create_session() as session:
try:
session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
session.execute(delete(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id))
session.execute(delete(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id))
session.execute(
delete(ToolConversationVariables).where(ToolConversationVariables.conversation_id == conversation_id)
)
session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.execute(delete(ToolFile).where(ToolFile.conversation_id == conversation_id))
session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
session.execute(delete(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id))
session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
session.execute(delete(Message).where(Message.conversation_id == conversation_id))
session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.execute(delete(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id))
session.commit()

View File

@@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -29,12 +29,12 @@ def delete_segment_from_index_task(
start_at = time.perf_counter()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
dataset_document = session.query(Document).where(Document.id == document_id).first()
dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not dataset_document:
return
@@ -60,11 +60,9 @@
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
segment_attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
).all()
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
@@ -77,7 +75,7 @@
session.execute(segment_attachment_bind_delete_stmt)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids)))
session.commit()
end_at = time.perf_counter()

View File

@@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1))
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
@@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
segment_attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
).all()
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
@@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.values(disabled_at=None, disabled_by=None, enabled=True)
)
session.commit()
finally:

View File

@@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
tenant_id = None
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
@@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
@@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
@@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.exception("Failed to clean vector index for document %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
@@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
@@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)

View File

@@ -47,7 +47,7 @@ def regenerate_summary_index_task(
try:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
@@ -84,8 +84,8 @@
# For embedding_model change: directly query all segments with existing summaries
# Don't require document indexing_status == "completed"
# Include summaries with status "completed" or "error" (if they have content)
segments_with_summaries = (
session.query(DocumentSegment, DocumentSegmentSummary)
segments_with_summaries = session.execute(
select(DocumentSegment, DocumentSegmentSummary)
.join(
DocumentSegmentSummary,
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -110,8 +110,7 @@
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
)
).all()
if not segments_with_summaries:
logger.info(
@@ -215,8 +214,8 @@
try:
# Get all segments with existing summaries
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.join(
DocumentSegmentSummary,
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -229,8 +228,7 @@
DocumentSegmentSummary.dataset_id == dataset_id,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if not segments:
continue
@@ -245,13 +243,13 @@
summary_record = None
try:
# Get existing summary record
summary_record = (
session.query(DocumentSegmentSummary)
.filter_by(
chunk_id=segment.id,
dataset_id=dataset_id,
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset_id,
)
.first()
.limit(1)
)
if not summary_record:
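Unlike the single-entity queries elsewhere in this commit, `select(DocumentSegment, DocumentSegmentSummary)` produces two-element `Row` objects, which is why the first hunk in this file switches to `session.execute(...)` rather than `session.scalars(...)`; scalars would silently keep only the first entity of each row. A small sketch of the distinction, with hypothetical models:

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Seg(Base):  # hypothetical stand-ins for the two joined models
    __tablename__ = "segs"
    id: Mapped[int] = mapped_column(primary_key=True)

class Summary(Base):
    __tablename__ = "summaries"
    id: Mapped[int] = mapped_column(primary_key=True)
    seg_id: Mapped[int] = mapped_column(ForeignKey("segs.id"))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    stmt = select(Seg, Summary).join(Summary, Seg.id == Summary.seg_id)
    for seg, summary in session.execute(stmt):  # each Row unpacks to both entities
        print(seg.id, summary.id)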

View File

@@ -0,0 +1,388 @@
from __future__ import annotations
import uuid
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.model import AccountTrialAppRecord, TrialApp
from services import recommended_app_service as service_module
from services.recommended_app_service import RecommendedAppService
# ── Helpers ────────────────────────────────────────────────────────────
def _apps_response(
recommended_apps: list[dict] | None = None,
categories: list[str] | None = None,
) -> dict:
if recommended_apps is None:
recommended_apps = [
{"id": "app-1", "name": "Test App 1", "description": "d1", "category": "productivity"},
{"id": "app-2", "name": "Test App 2", "description": "d2", "category": "communication"},
]
if categories is None:
categories = ["productivity", "communication", "utilities"]
return {"recommended_apps": recommended_apps, "categories": categories}
def _app_detail(
app_id: str = "app-123",
name: str = "Test App",
description: str = "Test description",
**kwargs: Any,
) -> dict:
detail: dict[str, Any] = {
"id": app_id,
"name": name,
"description": description,
"category": kwargs.get("category", "productivity"),
"icon": kwargs.get("icon", "🚀"),
"model_config": kwargs.get("model_config", {}),
}
detail.update(kwargs)
return detail
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any] | None:
return cast("dict[str, Any] | None", result)
def _mock_factory_for_apps(
monkeypatch: pytest.MonkeyPatch,
*,
mode: str,
result: dict[str, Any],
fallback_result: dict[str, Any] | None = None,
) -> tuple[MagicMock, MagicMock]:
retrieval_instance = MagicMock()
retrieval_instance.get_recommended_apps_and_categories.return_value = result
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
builtin_instance = MagicMock()
if fallback_result is not None:
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_buildin_recommend_app_retrieval",
MagicMock(return_value=builtin_instance),
)
return retrieval_instance, builtin_instance
# ── Pure logic tests: get_recommended_apps_and_categories ──────────────
class TestRecommendedAppServiceGetApps:
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_success_with_apps(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected = _apps_response()
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = expected
mock_factory = MagicMock(return_value=mock_instance)
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
assert result == expected
assert len(result["recommended_apps"]) == 2
assert len(result["categories"]) == 3
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
empty_response = {"recommended_apps": [], "categories": []}
builtin_response = _apps_response(
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
)
mock_remote_instance = MagicMock()
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_remote_instance)
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
assert result == builtin_response
assert result["recommended_apps"][0]["id"] == "builtin-1"
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
none_response = {"recommended_apps": None, "categories": ["test"]}
builtin_response = _apps_response()
mock_db_instance = MagicMock()
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_db_instance)
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_different_languages(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
for language in ["en-US", "zh-CN", "ja-JP", "fr-FR"]:
lang_response = _apps_response(
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
)
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_recommended_apps_and_categories(language)
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_uses_correct_factory_mode(self, mock_config, mock_factory_class):
for mode in ["remote", "builtin", "db"]:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
response = _apps_response()
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
RecommendedAppService.get_recommended_apps_and_categories("en-US")
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
# ── Pure logic tests: get_recommend_app_detail ─────────────────────────
class TestRecommendedAppServiceGetDetail:
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_success(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app")
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-123"))
assert result == expected
assert result["id"] == "app-123"
mock_instance.get_recommend_app_detail.assert_called_once_with("app-123")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_different_modes(self, mock_config, mock_factory_class):
for mode in ["remote", "builtin", "db"]:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
detail = _app_detail(app_id="test-app", name=f"App from {mode}")
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = detail
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("test-app"))
assert result["name"] == f"App from {mode}"
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_none_when_not_found(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = None
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("nonexistent"))
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_returns_empty_dict(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = {}
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-empty"))
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_complex_model_config(self, mock_config, mock_factory_class):
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
complex_config = {
"provider": "openai",
"model": "gpt-4",
"parameters": {"temperature": 0.7, "max_tokens": 2000, "top_p": 1.0},
}
expected = _app_detail(
app_id="complex-app",
name="Complex App",
model_config=complex_config,
workflows=["workflow-1", "workflow-2"],
tools=["tool-1", "tool-2", "tool-3"],
)
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("complex-app"))
assert result["model_config"] == complex_config
assert len(result["workflows"]) == 2
assert len(result["tools"]) == 3
# ── Integration tests: trial app features (real DB) ────────────────────
class TestRecommendedAppServiceTrialFeatures:
def test_get_apps_should_not_query_trial_table_when_disabled(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
retrieval_instance, builtin_instance = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
)
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
assert result == expected
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
def test_get_apps_should_enrich_can_trial_when_enabled(
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
):
app_id_1 = str(uuid.uuid4())
app_id_2 = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
# app_id_1 has a TrialApp record; app_id_2 does not
db_session_with_containers.add(TrialApp(app_id=app_id_1, tenant_id=tenant_id))
db_session_with_containers.commit()
remote_result = {"recommended_apps": [], "categories": []}
fallback_result = {
"recommended_apps": [{"app_id": app_id_1}, {"app_id": app_id_2}],
"categories": ["all"],
}
_, builtin_instance = _mock_factory_for_apps(
monkeypatch, mode="remote", result=remote_result, fallback_result=fallback_result
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
assert result["recommended_apps"][0]["can_trial"] is True
assert result["recommended_apps"][1]["can_trial"] is False
@pytest.mark.parametrize("has_trial_app", [True, False])
def test_get_detail_should_set_can_trial_when_enabled(
self,
db_session_with_containers: Session,
monkeypatch: pytest.MonkeyPatch,
has_trial_app: bool,
):
app_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
if has_trial_app:
db_session_with_containers.add(TrialApp(app_id=app_id, tenant_id=tenant_id))
db_session_with_containers.commit()
detail = {"id": app_id, "name": "Test App"}
retrieval_instance = MagicMock()
retrieval_instance.get_recommend_app_detail.return_value = detail
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail(app_id))
assert result["id"] == app_id
assert result["can_trial"] is has_trial_app
def test_add_trial_app_record_increments_count_for_existing(self, db_session_with_containers: Session):
app_id = str(uuid.uuid4())
account_id = str(uuid.uuid4())
db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3))
db_session_with_containers.commit()
RecommendedAppService.add_trial_app_record(app_id, account_id)
db_session_with_containers.expire_all()
record = db_session_with_containers.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
.limit(1)
)
assert record is not None
assert record.count == 4
def test_add_trial_app_record_creates_new_record(self, db_session_with_containers: Session):
app_id = str(uuid.uuid4())
account_id = str(uuid.uuid4())
RecommendedAppService.add_trial_app_record(app_id, account_id)
db_session_with_containers.expire_all()
record = db_session_with_containers.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
.limit(1)
)
assert record is not None
assert record.app_id == app_id
assert record.account_id == account_id
assert record.count == 1
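The two tests above pin down a get-or-create upsert: bump the counter when a record already exists, otherwise insert a fresh row with count=1. A minimal sketch of that pattern, assuming a plain SQLAlchemy session and the same model/select imports as the tests (the helper below is illustrative, not the service's actual implementation):
def _add_trial_app_record(session, app_id: str, account_id: str) -> None:
    # look up the per-account record for this app, if any
    record = session.scalar(
        select(AccountTrialAppRecord)
        .where(
            AccountTrialAppRecord.app_id == app_id,
            AccountTrialAppRecord.account_id == account_id,
        )
        .limit(1)
    )
    if record is None:
        # first trial for this account/app pair
        session.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=1))
    else:
        record.count += 1
    session.commit()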

View File

@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Note: Since we're mocking ConversationVariable.from_variable,
# we can't directly check the id, but we can verify add_all was called
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_no_variables_creates_all(self):
"""Test that all conversation variables are created when none exist in DB."""
@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Verify that all variables were created
assert len(added_items) == 2, "Should have added both variables"
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_all_variables_exist_no_changes(self):
"""Test that no changes are made when all variables already exist in DB."""
@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Verify that no variables were added
assert not mock_session.add_all.called, "Session add_all should not have been called"
assert mock_session.commit.called, "Session commit should still be called"

View File

@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner):
scalar=lambda *a, **k: MagicMock(),
),
),
sessionmaker=MagicMock(
return_value=MagicMock(
begin=MagicMock(
return_value=MagicMock(
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
__exit__=lambda *a, **k: False,
),
),
),
),
select=MagicMock(),
db=MagicMock(engine=MagicMock()),
RedisChannel=MagicMock(),

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus
@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline:
def test_database_session_rolls_back_on_error(self, monkeypatch):
pipeline = _make_pipeline()
calls = {"commit": 0, "rollback": 0}
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
calls = {"enter": 0, "exit_exc": None}
class _BeginContext:
def __enter__(self):
return self
calls["enter"] += 1
return MagicMock()
def __exit__(self, exc_type, exc, tb):
calls["exit_exc"] = exc_type
return False
def commit(self):
calls["commit"] += 1
class _Sessionmaker:
def __init__(self, *args, **kwargs):
pass
def rollback(self):
calls["rollback"] += 1
def begin(self):
return _BeginContext()
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
with pytest.raises(RuntimeError, match="db error"):
with pipeline._database_session():
raise RuntimeError("db error")
assert calls["commit"] == 0
assert calls["rollback"] == 1
assert calls["enter"] == 1
assert calls["exit_exc"] is RuntimeError
def test_node_retry_and_started_handlers_cover_none_and_value(self):
pipeline = _make_pipeline()

View File

@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation():
assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
vector.delete()
runner._create_collection_if_not_exists.assert_called_once_with(2)
runner.create_collection_if_not_exists.assert_called_once_with(2)
runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
runner.delete_by_ids.assert_called_once_with(["d1"])
runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")

View File

@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
vector._client = MagicMock()
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
vector._create_collection_if_not_exists(embedding_dimension=1024)
vector.create_collection_if_not_exists(embedding_dimension=1024)
vector._client.create_collection.assert_called_once()
openapi_module.redis_client.set.assert_called_once()
@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
vector.config = _config()
vector._client = MagicMock()
vector._create_collection_if_not_exists(embedding_dimension=1024)
vector.create_collection_if_not_exists(embedding_dimension=1024)
vector._client.describe_collection.assert_not_called()
vector._client.create_collection.assert_not_called()
@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
with pytest.raises(ValueError, match="failed to create collection collection_1"):
vector._create_collection_if_not_exists(embedding_dimension=512)
vector.create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):

View File

@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
vector._get_cursor = _cursor_context
vector._create_collection_if_not_exists(embedding_dimension=3)
vector.create_collection_if_not_exists(embedding_dimension=3)
assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat
vector._get_cursor = _cursor_context
with pytest.raises(RuntimeError, match="permission denied"):
vector._create_collection_if_not_exists(embedding_dimension=3)
vector.create_collection_if_not_exists(embedding_dimension=3)
def test_delete_methods_raise_when_error_is_not_missing_table():

View File

@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
return _session
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _sessionmaker_factory(calls, execute_results=None):
def _sessionmaker(*args, **kwargs):
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
return _sessionmaker
def _patch_both(monkeypatch, module, calls, execute_results=None):
"""Patch both Session and sessionmaker on the module with the same call tracker."""
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
@pytest.fixture
def pgvecto_module(monkeypatch):
for name, module in _build_fake_pgvecto_modules().items():
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
vector.create_collection = MagicMock()
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
_patch_both(monkeypatch, module, session_calls)
lock = MagicMock()
lock.__enter__.return_value = None
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
class _InsertBuilder:
def __init__(self, table):
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
"Session",
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
monkeypatch.setattr(
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
],
),
)
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
vector.delete_by_ids(["doc-1"])
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
runtime_calls.clear()
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
vector.delete()
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
module, _ = pgvecto_module
init_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
_patch_both(monkeypatch, module, init_calls)
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
runtime_calls = []
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
]
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
assert len(docs) == 1

View File

@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage:
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
sessionmaker_ctx = MagicMock()
sessionmaker_ctx.begin.return_value = session_ctx
with (
patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())),
patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx),
patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx),
patch.object(retrieval, "_send_trace_task") as mock_trace,
):
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
query.update.assert_called_once()
session.commit.assert_called_once()
mock_trace.assert_called_once()
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:

View File

@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = None
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [None, None]
# Act
with _patch_session_factory(session):
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
manager = ToolFileManager()
message_file = SimpleNamespace(url=None)
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = None
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, None]
# Act
with _patch_session_factory(session):
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
session = Mock()
first_query = Mock()
second_query = Mock()
first_query.where.return_value.first.return_value = message_file
second_query.where.return_value.first.return_value = tool_file
session.query.side_effect = [first_query, second_query]
session.scalar.side_effect = [message_file, tool_file]
# Act
with patch("core.tools.tool_file_manager.storage") as storage:
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
# Arrange
manager = ToolFileManager()
session = Mock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
# Act
with _patch_session_factory(session):
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
size=12,
)
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
session.scalar.return_value = tool_file
# Act
with patch("core.tools.tool_file_manager.storage") as storage:

View File

@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
controller = _controller()
session = Mock()
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
session.query.return_value.where.return_value.first.return_value = workflow
session.scalar.return_value = workflow
app = SimpleNamespace(id="app-1")
db_provider = SimpleNamespace(
id="provider-1",
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
parameter_configurations=[],
)
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.side_effect = [app, user]
fake_cm = MagicMock()
fake_cm.__enter__.return_value = session
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
session_cls.return_value.__enter__.return_value = session
assert controller.get_tools("tenant-1") == []
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
mock_db.engine = object()
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
session = _mock_session_with_begin()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
session.get.return_value = None
session_cls.return_value.__enter__.return_value = session
with pytest.raises(ValueError, match="app not found"):

View File

@ -110,6 +110,34 @@ class TestFetchMemory:
)
class TestDifyGraphInitContext:
def test_to_graph_init_params_preserves_explicit_values(self):
run_context = {
DIFY_RUN_CONTEXT_KEY: DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
"extra": "value",
}
graph_config = {"nodes": [], "edges": []}
graph_init_context = node_factory.DifyGraphInitContext(
workflow_id="workflow-id",
graph_config=graph_config,
run_context=run_context,
call_depth=2,
)
result = graph_init_context.to_graph_init_params()
assert result.workflow_id == "workflow-id"
assert result.graph_config == graph_config
assert result.run_context == run_context
assert result.call_depth == 2
class TestDefaultWorkflowCodeExecutor:
def test_execute_delegates_to_code_executor(self, monkeypatch):
executor = node_factory.DefaultWorkflowCodeExecutor()
@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:
class TestDifyNodeFactoryInit:
def test_from_graph_init_context_translates_before_init(self):
graph_init_context = MagicMock()
graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params
with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
factory = node_factory.DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
assert isinstance(factory, node_factory.DifyNodeFactory)
graph_init_context.to_graph_init_params.assert_called_once_with()
init.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
graph_runtime_state = sentinel.graph_runtime_state

View File

@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
]
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(
workflow_entry,
"GraphRuntimeState",
@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
patch.object(
@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
raise NotImplementedError
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
return {"question": ["node", "question"]}
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
return "1"
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool"),
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
patch.object(
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
) as graph_init_params,
workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
) as graph_init_context_cls,
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
) as build_dify_run_context,
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
) as dify_node_factory_cls,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_params.assert_called_once_with(
graph_init_context_cls.assert_called_once_with(
workflow_id="",
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
"node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
call_depth=0,
)
dify_node_factory_cls.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_init_context=sentinel.graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
mapping_user_inputs_to_variable_pool.assert_called_once_with(
@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
patch.object(workflow_entry, "add_variables_to_pool"),
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
),
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",

View File

@ -1,53 +1,125 @@
from unittest.mock import patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import redis_fallback
from extensions.ext_redis import (
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
redis_fallback,
)
def test_redis_fallback_success():
@redis_fallback(default_return=None)
def test_func():
return "success"
class TestGetConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_all_health_params(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "success"
params = _get_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" in params
assert isinstance(params["retry"], Retry)
assert params["retry"]._retries == 3
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
def test_redis_fallback_error():
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
class TestGetClusterConnectionHealthParams:
@patch("extensions.ext_redis.dify_config")
def test_excludes_health_check_interval(self, mock_config):
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() == "fallback"
params = _get_cluster_connection_health_params()
assert "retry" in params
assert "socket_timeout" in params
assert "socket_connect_timeout" in params
assert "health_check_interval" not in params
def test_redis_fallback_none_default():
@redis_fallback()
def test_func():
raise RedisError("Redis error")
class TestGetBaseRedisParams:
@patch("extensions.ext_redis.dify_config")
def test_includes_retry_and_health_params(self, mock_config):
mock_config.REDIS_USERNAME = None
mock_config.REDIS_PASSWORD = None
mock_config.REDIS_DB = 0
mock_config.REDIS_SERIALIZATION_PROTOCOL = 3
mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False
mock_config.REDIS_RETRY_RETRIES = 3
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
assert test_func() is None
params = _get_base_redis_params()
assert "retry" in params
assert isinstance(params["retry"], Retry)
assert params["socket_timeout"] == 5.0
assert params["socket_connect_timeout"] == 5.0
assert params["health_check_interval"] == 30
# Existing params still present
assert params["db"] == 0
assert params["encoding"] == "utf-8"
def test_redis_fallback_with_args():
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
class TestRedisFallback:
def test_redis_fallback_success(self):
@redis_fallback(default_return=None)
def test_func():
return "success"
assert test_func(1, 2) == 0
assert test_func() == "success"
def test_redis_fallback_error(self):
@redis_fallback(default_return="fallback")
def test_func():
raise RedisError("Redis error")
def test_redis_fallback_with_kwargs():
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func() == "fallback"
assert test_func(x=1, y=2) == {}
def test_redis_fallback_none_default(self):
@redis_fallback()
def test_func():
raise RedisError("Redis error")
assert test_func() is None
def test_redis_fallback_preserves_function_metadata():
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
def test_redis_fallback_with_args(self):
@redis_fallback(default_return=0)
def test_func(x, y):
raise RedisError("Redis error")
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
assert test_func(1, 2) == 0
def test_redis_fallback_with_kwargs(self):
@redis_fallback(default_return={})
def test_func(x=None, y=None):
raise RedisError("Redis error")
assert test_func(x=1, y=2) == {}
def test_redis_fallback_preserves_function_metadata(self):
@redis_fallback(default_return=None)
def test_func():
"""Test function docstring"""
pass
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
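The parameter tests above expect a redis-py Retry object built from the retry settings. A sketch of how such kwargs are typically wired into a client, with values mirroring the mocked config (illustrative, not the module's actual code):
from redis import Redis
from redis.backoff import ExponentialBackoff
from redis.retry import Retry
# exponential backoff between attempts, capped at 10s, with up to 3 retries
retry = Retry(ExponentialBackoff(cap=10.0, base=1.0), retries=3)
client = Redis(
    host="localhost",  # placeholder host
    retry=retry,
    socket_timeout=5.0,
    socket_connect_timeout=5.0,
    health_check_interval=30,
)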

View File

@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
@ -20,7 +20,7 @@ class TestGetStrategy:
def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session()
strategy = MagicMock()
session.query.return_value.where.return_value.first.return_value = strategy
session.scalar.return_value = strategy
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -31,7 +31,7 @@ class TestGetStrategy:
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -44,9 +44,9 @@ class TestGetStrategy:
class TestChangeStrategy:
def test_creates_new_strategy(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -61,12 +61,11 @@ class TestChangeStrategy:
assert result is True
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -86,17 +85,17 @@ class TestChangeStrategy:
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
assert existing.exclude_plugins == ["p1"]
assert existing.include_plugins == ["p2"]
session.commit.assert_called_once()
class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with (
p1,
p2,
patch(f"{MODULE}.select"),
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
):
@ -115,9 +114,9 @@ class TestExcludePlugin:
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -127,16 +126,15 @@ class TestExcludePlugin:
assert result is True
assert existing.exclude_plugins == ["p-existing", "p-new"]
session.commit.assert_called_once()
def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -151,9 +149,9 @@ class TestExcludePlugin:
p1, p2, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "all"
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -170,9 +168,9 @@ class TestExcludePlugin:
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"]
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"

View File

@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
@ -20,7 +20,7 @@ class TestGetPermission:
def test_returns_permission_when_found(self):
p1, p2, session = _patched_session()
permission = MagicMock()
session.query.return_value.where.return_value.first.return_value = permission
session.scalar.return_value = permission
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
@ -31,7 +31,7 @@ class TestGetPermission:
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
@ -44,9 +44,9 @@ class TestGetPermission:
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
@ -55,12 +55,11 @@ class TestChangePermission:
)
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
existing = MagicMock()
session.query.return_value.where.return_value.first.return_value = existing
session.scalar.return_value = existing
with p1, p2:
from services.plugin.plugin_permission_service import PluginPermissionService
@ -71,5 +70,4 @@ class TestChangePermission:
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.commit.assert_called_once()
session.add.assert_not_called()
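The session.scalar assertions above track the move from the legacy Query API to 2.0-style statements. For a single-row lookup the two forms are roughly equivalent (sketch; the filter shown is illustrative):
from sqlalchemy import select
# legacy Query API, as the old tests expected
permission = session.query(TenantPluginPermission).where(
    TenantPluginPermission.tenant_id == tenant_id
).first()
# 2.0-style equivalent the updated tests expect
permission = session.scalar(
    select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
)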

View File

@ -275,48 +275,46 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
msg_session_1.query.side_effect = lambda model: (
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
)
msg_session_1.commit.return_value = None
msg_session_2 = MagicMock()
msg_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
)
msg_session_2.commit.return_value = None
conv_session_1 = MagicMock()
conv_session_1.query.side_effect = lambda model: (
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
)
conv_session_1.commit.return_value = None
conv_session_2 = MagicMock()
conv_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
)
conv_session_2.commit.return_value = None
wal_session_1 = MagicMock()
wal_session_1.query.side_effect = lambda model: (
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_1.commit.return_value = None
wal_session_2 = MagicMock()
wal_session_2.query.side_effect = lambda model: (
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
)
wal_session_2.commit.return_value = None
session_wrappers = [
_session_wrapper_for_no_autoflush(msg_session_1),
_session_wrapper_for_no_autoflush(msg_session_2),
_session_wrapper_for_no_autoflush(conv_session_1),
_session_wrapper_for_no_autoflush(conv_session_2),
_session_wrapper_for_no_autoflush(wal_session_1),
_session_wrapper_for_no_autoflush(wal_session_2),
_sessionmaker_wrapper_for_begin(msg_session_1),
_sessionmaker_wrapper_for_begin(msg_session_2),
_sessionmaker_wrapper_for_begin(conv_session_1),
_sessionmaker_wrapper_for_begin(conv_session_2),
_sessionmaker_wrapper_for_begin(wal_session_1),
_sessionmaker_wrapper_for_begin(wal_session_2),
]
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
def fake_sessionmaker(*args, **kwargs):
if kwargs.get("autoflush") is False:
return session_wrappers.pop(0)
return object()
monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker)
def fake_select(*_args, **_kwargs):
stmt = MagicMock()
@ -333,8 +331,6 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
run_repo = MagicMock()
run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []]
run_repo.delete_runs_by_ids.return_value = 1
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
monkeypatch.setattr(
service_module.DifyAPIRepositoryFactory,
"create_api_workflow_node_execution_repository",
@ -574,13 +570,18 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
q_empty.limit.return_value = q_empty
q_empty.all.return_value = []
empty_session.query.return_value = q_empty
empty_session.commit.return_value = None
session_wrappers = [
_session_wrapper_for_no_autoflush(empty_session),
_session_wrapper_for_no_autoflush(empty_session),
_session_wrapper_for_no_autoflush(empty_session),
_sessionmaker_wrapper_for_begin(empty_session),
_sessionmaker_wrapper_for_begin(empty_session),
_sessionmaker_wrapper_for_begin(empty_session),
]
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
def fake_sessionmaker(*args, **kwargs):
if kwargs.get("autoflush") is False:
return session_wrappers.pop(0)
return object()
monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker)
def fake_select(*_args, **_kwargs):
stmt = MagicMock()
@ -606,8 +607,6 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
[],
]
run_repo.delete_runs_by_ids.return_value = 2
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
monkeypatch.setattr(
service_module.DifyAPIRepositoryFactory,
"create_api_workflow_node_execution_repository",

View File

@ -40,7 +40,10 @@ class TestDatasourceProviderService:
q returns itself for .filter_by(), .order_by(), .where() so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
"""
with patch("services.datasource_provider_service.Session") as mock_cls:
with (
patch("services.datasource_provider_service.Session") as mock_cls,
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
):
sess = MagicMock(spec=Session)
q = MagicMock()
@ -63,6 +66,8 @@ class TestDatasourceProviderService:
mock_cls.return_value.__enter__.return_value = sess
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
yield sess
@ -266,7 +271,6 @@ class TestDatasourceProviderService:
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
):
service.get_datasource_credentials("t1", "prov", "org/plug")
mock_db_session.commit.assert_called_once()
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
"""API key credentials with expires_at=-1 skip refresh and return directly."""
@ -333,7 +337,6 @@ class TestDatasourceProviderService:
p.name = "same"
mock_db_session.query().first.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
mock_db_session.commit.assert_not_called()
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -352,7 +355,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# set_default_datasource_provider (lines 277-303)
@ -370,7 +372,6 @@ class TestDatasourceProviderService:
mock_db_session.query().first.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# get_oauth_encrypter (lines 404-420)
@ -460,7 +461,6 @@ class TestDatasourceProviderService:
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
@ -512,7 +512,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -523,7 +522,6 @@ class TestDatasourceProviderService:
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
)
mock_db_session.commit.assert_called_once()
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@ -571,7 +569,6 @@ class TestDatasourceProviderService:
):
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
@ -747,7 +744,6 @@ class TestDatasourceProviderService:
# encrypter must have been called with the new secret value
self._enc.encrypt_token.assert_called()
# commit must be called exactly once
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# remove_datasource_credentials (lines 980-997)
@ -758,7 +754,6 @@ class TestDatasourceProviderService:
mock_db_session.scalar.return_value = p
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_called_once_with(p)
mock_db_session.commit.assert_called_once()
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""

View File

@ -1,628 +0,0 @@
"""
Comprehensive unit tests for RecommendedAppService.
This test suite provides complete coverage of recommended app operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
Tests fetching recommended apps with categories:
- Successful retrieval with recommended apps
- Fallback to builtin when no recommended apps
- Different language support
- Factory mode selection (remote, builtin, db)
- Empty result handling
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
Tests fetching individual app details:
- Successful app detail retrieval
- Different factory modes
- App not found scenarios
- Language-specific details
## Testing Approach
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
are mocked for fast, isolated unit tests
- **Factory Pattern**: Tests verify correct factory selection based on mode
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and factory method calls
## Key Concepts
**Factory Modes:**
- remote: Fetch from remote API
- builtin: Use built-in templates
- db: Fetch from database
**Fallback Logic:**
- If remote/db returns no apps, fallback to builtin en-US templates
- Ensures users always see some recommended apps
"""
from unittest.mock import MagicMock, patch
import pytest
from services.recommended_app_service import RecommendedAppService
class RecommendedAppServiceTestDataFactory:
"""
Factory for creating test data and mock objects.
Provides reusable methods to create consistent mock objects for testing
recommended app operations.
"""
@staticmethod
def create_recommended_apps_response(
recommended_apps: list[dict] | None = None,
categories: list[str] | None = None,
) -> dict:
"""
Create a mock response for recommended apps.
Args:
recommended_apps: List of recommended app dictionaries
categories: List of category names
Returns:
Dictionary with recommended_apps and categories
"""
if recommended_apps is None:
recommended_apps = [
{
"id": "app-1",
"name": "Test App 1",
"description": "Test description 1",
"category": "productivity",
},
{
"id": "app-2",
"name": "Test App 2",
"description": "Test description 2",
"category": "communication",
},
]
if categories is None:
categories = ["productivity", "communication", "utilities"]
return {
"recommended_apps": recommended_apps,
"categories": categories,
}
@staticmethod
def create_app_detail_response(
app_id: str = "app-123",
name: str = "Test App",
description: str = "Test description",
**kwargs,
) -> dict:
"""
Create a mock response for app detail.
Args:
app_id: App identifier
name: App name
description: App description
**kwargs: Additional fields
Returns:
Dictionary with app details
"""
detail = {
"id": app_id,
"name": name,
"description": description,
"category": kwargs.get("category", "productivity"),
"icon": kwargs.get("icon", "🚀"),
"model_config": kwargs.get("model_config", {}),
}
detail.update(kwargs)
return detail
@pytest.fixture
def factory():
"""Provide the test data factory to all tests."""
return RecommendedAppServiceTestDataFactory
class TestRecommendedAppServiceGetApps:
"""Test get_recommended_apps_and_categories operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of recommended apps when apps are returned."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected_response = factory.create_recommended_apps_response()
# Mock factory and retrieval instance
mock_retrieval_instance = MagicMock()
mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response
mock_factory = MagicMock()
mock_factory.return_value = mock_retrieval_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == expected_response
assert len(result["recommended_apps"]) == 2
assert len(result["categories"]) == 3
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
"""Test fallback to builtin when no recommended apps are returned."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
# Remote returns empty recommended_apps
empty_response = {"recommended_apps": [], "categories": []}
# Builtin fallback response
builtin_response = factory.create_recommended_apps_response(
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
)
# Mock remote retrieval instance (returns empty)
mock_remote_instance = MagicMock()
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
mock_remote_factory = MagicMock()
mock_remote_factory.return_value = mock_remote_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory
# Mock builtin retrieval instance
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
# Assert
assert result == builtin_response
assert len(result["recommended_apps"]) == 1
assert result["recommended_apps"][0]["id"] == "builtin-1"
# Verify fallback was called with en-US (hardcoded)
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
"""Test fallback when recommended_apps key is None."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
# Response with None recommended_apps
none_response = {"recommended_apps": None, "categories": ["test"]}
# Builtin fallback response
builtin_response = factory.create_recommended_apps_response()
# Mock db retrieval instance (returns None)
mock_db_instance = MagicMock()
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
mock_db_factory = MagicMock()
mock_db_factory.return_value = mock_db_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory
# Mock builtin retrieval instance
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
"""Test retrieval with different language codes."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]
for language in languages:
# Create language-specific response
lang_response = factory.create_recommended_apps_response(
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommended_apps_and_categories(language)
# Assert
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
"""Test that correct factory is selected based on mode."""
# Arrange
modes = ["remote", "builtin", "db"]
for mode in modes:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
response = factory.create_recommended_apps_response()
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = response
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
class TestRecommendedAppServiceGetDetail:
"""Test get_recommend_app_detail operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of app detail."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "app-123"
expected_detail = factory.create_app_detail_response(
app_id=app_id,
name="Productivity App",
description="A great productivity app",
category="productivity",
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected_detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == expected_detail
assert result["id"] == app_id
assert result["name"] == "Productivity App"
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
"""Test app detail retrieval with different factory modes."""
# Arrange
modes = ["remote", "builtin", "db"]
app_id = "test-app"
for mode in modes:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["name"] == f"App from {mode}"
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
"""Test that None is returned when app is not found."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "nonexistent-app"
# Mock retrieval instance returning None
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = None
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
"""Test handling of empty dict response."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
app_id = "app-empty"
# Mock retrieval instance returning empty dict
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = {}
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
"""Test app detail with complex model configuration."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "complex-app"
complex_model_config = {
"provider": "openai",
"model": "gpt-4",
"parameters": {
"temperature": 0.7,
"max_tokens": 2000,
"top_p": 1.0,
},
}
expected_detail = factory.create_app_detail_response(
app_id=app_id,
name="Complex App",
model_config=complex_model_config,
workflows=["workflow-1", "workflow-2"],
tools=["tool-1", "tool-2", "tool-3"],
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected_detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
# Assert
assert result["model_config"] == complex_model_config
assert len(result["workflows"]) == 2
assert len(result["tools"]) == 3
# === Merged from test_recommended_app_service_additional.py ===
from types import SimpleNamespace
from typing import Any, cast

from services import recommended_app_service as service_module
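# NOTE: _recommendation_detail below is referenced by TestRecommendedAppServiceGetDetail
# above. Module-level names resolve at call time, so the late definition is safe; the
# helper is only a typing shim around the service's `dict | None` return value.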
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
return cast(dict[str, Any], result)
@pytest.fixture
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    # Replace the real SQLAlchemy session with a MagicMock so tests can
    # assert on scalar/add/commit calls without touching a database.
    session = MagicMock()
    monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session))
    return session
def _mock_factory_for_apps(
monkeypatch: pytest.MonkeyPatch,
*,
mode: str,
result: dict[str, Any],
fallback_result: dict[str, Any] | None = None,
) -> tuple[MagicMock, MagicMock]:
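    """
    Patch the service's retrieval factory for the given mode and, optionally,
    the builtin fallback retrieval; returns the (retrieval, builtin) mock
    instances so tests can assert on their calls.
    """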
retrieval_instance = MagicMock()
retrieval_instance.get_recommended_apps_and_categories.return_value = result
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
builtin_instance = MagicMock()
if fallback_result is not None:
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_buildin_recommend_app_retrieval",
MagicMock(return_value=builtin_instance),
)
return retrieval_instance, builtin_instance
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
) -> None:
# Arrange
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
retrieval_instance, builtin_instance = _mock_factory_for_apps(
monkeypatch,
mode="remote",
result=expected,
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
)
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == expected
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
mocked_db_session.scalar.assert_not_called()
def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
) -> None:
# Arrange
remote_result = {"recommended_apps": [], "categories": []}
fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
_, builtin_instance = _mock_factory_for_apps(
monkeypatch,
mode="remote",
result=remote_result,
fallback_result=fallback_result,
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
# Assert
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
assert result["recommended_apps"][0]["can_trial"] is True
assert result["recommended_apps"][1]["can_trial"] is False
assert mocked_db_session.scalar.call_count == 2
@pytest.mark.parametrize(
("trial_query_result", "expected_can_trial"),
[
(SimpleNamespace(id="trial"), True),
(None, False),
],
)
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
monkeypatch: pytest.MonkeyPatch,
mocked_db_session: MagicMock,
trial_query_result: Any,
expected_can_trial: bool,
) -> None:
# Arrange
detail = {"id": "app-1", "name": "Test App"}
retrieval_instance = MagicMock()
retrieval_instance.get_recommend_app_detail.return_value = detail
retrieval_factory = MagicMock(return_value=retrieval_instance)
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=retrieval_factory),
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
mocked_db_session.scalar.return_value = trial_query_result
# Act
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))
# Assert
assert result["id"] == "app-1"
assert result["can_trial"] is expected_can_trial
mocked_db_session.scalar.assert_called_once()
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
mocked_db_session: MagicMock,
) -> None:
# Arrange
existing_record = SimpleNamespace(count=3)
mocked_db_session.scalar.return_value = existing_record
# Act
RecommendedAppService.add_trial_app_record("app-1", "account-1")
# Assert
assert existing_record.count == 4
mocked_db_session.scalar.assert_called_once()
mocked_db_session.commit.assert_called_once()
mocked_db_session.add.assert_not_called()
def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
mocked_db_session: MagicMock,
) -> None:
# Arrange
mocked_db_session.scalar.return_value = None
# Act
RecommendedAppService.add_trial_app_record("app-2", "account-2")
# Assert
mocked_db_session.scalar.assert_called_once()
mocked_db_session.add.assert_called_once()
added = mocked_db_session.add.call_args.args[0]
assert added.app_id == "app-2"
assert added.account_id == "account-2"
assert added.count == 1
mocked_db_session.commit.assert_called_once()
View File
@@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock:
mock_session_cm.__enter__.return_value = mock_session_instance
mock_session_cm.__exit__.return_value = False
mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm)
mock_begin_cm = MagicMock()
mock_begin_cm.__enter__.return_value = mock_session_instance
mock_begin_cm.__exit__.return_value = False
mock_sessionmaker_instance = MagicMock()
mock_sessionmaker_instance.begin.return_value = mock_begin_cm
mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance)
return mock_session_instance
@@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
# Assert
assert result["result"] == "success"
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type(
@@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
assert subscription.credentials == {"api_key": "new-key"}
assert subscription.credential_expires_at == 100
assert subscription.expires_at == 200
mock_session.commit.assert_called_once()
mock_delete_cache.assert_called_once()
@@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
assert result == {"result": "success", "expires_at": 12345}
assert subscription.credentials == {"access_token": "new"}
assert subscription.credential_expires_at == 12345
mock_session.commit.assert_called_once()
cache.delete.assert_called_once()
@@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
assert result == {"result": "success", "expires_at": 999}
assert subscription.properties == {"p": "new-enc"}
assert subscription.expires_at == 999
mock_session.commit.assert_called_once()
prop_cache.delete.assert_called_once()
@@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
assert fake_model.encrypted_oauth_params == "{}"
assert fake_model.enabled is True
mock_session.add.assert_called_once_with(fake_model)
mock_session.commit.assert_called_once()
def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache(
@@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
assert result == {"result": "success"}
assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"}
cache.delete.assert_called_once()
mock_session.commit.assert_called_once()
def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
@@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
# Assert
assert result == {"result": "success"}
mock_session.commit.assert_called_once()
@pytest.mark.parametrize("exists", [True, False])
View File
@@ -617,6 +617,20 @@ class _SessionContext:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
@@ -625,6 +639,7 @@ def flask_app() -> Flask:
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
@@ -1240,7 +1255,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
assert fake_session.commit_count == 2
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()
Some files were not shown because too many files have changed in this diff