mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
Merge branch 'main' into deploy/dev
This commit is contained in:
commit
d6c3df33c1
@ -71,6 +71,13 @@ REDIS_USE_CLUSTERS=false
|
||||
REDIS_CLUSTERS=
|
||||
REDIS_CLUSTERS_PASSWORD=
|
||||
|
||||
REDIS_RETRY_RETRIES=3
|
||||
REDIS_RETRY_BACKOFF_BASE=1.0
|
||||
REDIS_RETRY_BACKOFF_CAP=10.0
|
||||
REDIS_SOCKET_TIMEOUT=5.0
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
|
||||
REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
@ -102,6 +109,7 @@ S3_BUCKET_NAME=your-bucket-name
|
||||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
S3_ADDRESS_STYLE=auto
|
||||
|
||||
# Workflow run and Conversation archive storage (S3-compatible)
|
||||
ARCHIVE_STORAGE_ENABLED=false
|
||||
|
||||
31
api/configs/middleware/cache/redis_config.py
vendored
31
api/configs/middleware/cache/redis_config.py
vendored
@ -117,6 +117,37 @@ class RedisConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
|
||||
description="Maximum number of retries per Redis command on "
|
||||
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
|
||||
description="Base delay in seconds for exponential backoff between retries",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
|
||||
description="Maximum backoff delay in seconds between retries",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis read/write operations",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Socket timeout in seconds for Redis connection establishment",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
|
||||
description="Interval in seconds between Redis connection health checks (0 to disable)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
|
||||
@ -48,11 +48,27 @@ class SavedMessageCreatePayload(BaseModel):
|
||||
# --- Workflow schemas ---
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
|
||||
@ -34,9 +34,10 @@ from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
|
||||
@ -17,8 +17,9 @@ from fields.app_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
@ -92,11 +93,13 @@ class AppImportApi(Resource):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
|
||||
@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||
@ -142,10 +143,6 @@ class PublishWorkflowPayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
@ -153,18 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
@ -384,24 +384,27 @@ class VariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=app_model.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
|
||||
@ -223,24 +223,27 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
match variable.value_type:
|
||||
case SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import (
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||
original_document_id: str | None = None
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class NodeIdQuery(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
@ -168,12 +168,13 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
case AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
case _:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
|
||||
@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App
|
||||
from models.account import AccountStatus
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
|
||||
|
||||
class InnerAppDSLImportPayload(BaseModel):
|
||||
|
||||
@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
if auth_type == WebAppAuthType.PUBLIC:
|
||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
|
||||
raise WebAppAuthRequiredError("Please login as external user.")
|
||||
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
|
||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||
match auth_type:
|
||||
case WebAppAuthType.PUBLIC:
|
||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||
case WebAppAuthType.EXTERNAL:
|
||||
if user_auth_type != "external":
|
||||
raise WebAppAuthRequiredError("Please login as external user.")
|
||||
case WebAppAuthType.INTERNAL:
|
||||
if user_auth_type != "internal":
|
||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||
|
||||
end_user = None
|
||||
if end_user_id:
|
||||
|
||||
@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
match app_mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
case AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
case _:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import VariableLoader
|
||||
from graphon.variables.variables import Variable
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Convert to Variable objects for use in the workflow
|
||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
|
||||
@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.enums import WorkflowType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
# graph_config["nodes"] = real_run_nodes
|
||||
# graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if start_node_id is None:
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Union
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.graph import Graph
|
||||
@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.node_factory import (
|
||||
DifyGraphInitContext,
|
||||
DifyNodeFactory,
|
||||
get_default_root_node_id,
|
||||
resolve_workflow_node_class,
|
||||
)
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
default_system_variables,
|
||||
@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
@ -57,37 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
match system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
case ProviderQuotaType.FREE:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
session.commit()
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
|
||||
@ -40,41 +40,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
match message_file.transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
case FileTransferMethod.TOOL_FILE if message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE:
|
||||
pass
|
||||
|
||||
transfer_method_value = message_file.transfer_method.value
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
|
||||
@ -187,15 +187,16 @@ def build_parameter_schema(
|
||||
|
||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
|
||||
"""Prepare arguments based on app mode"""
|
||||
if app.mode == AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
else:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
match app.mode:
|
||||
case AppMode.WORKFLOW:
|
||||
return {"inputs": arguments}
|
||||
case AppMode.COMPLETION:
|
||||
return {"query": "", "inputs": arguments}
|
||||
case _:
|
||||
# Chat modes - create a copy to avoid modifying original dict
|
||||
args_copy = arguments.copy()
|
||||
query = args_copy.pop("query", "")
|
||||
return {"query": query, "inputs": args_copy}
|
||||
|
||||
|
||||
def extract_answer_from_response(app: App, response: Any) -> str:
|
||||
@ -229,17 +230,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
|
||||
|
||||
def process_mapping_response(app: App, response: Mapping) -> str:
|
||||
"""Process mapping response based on app mode"""
|
||||
if app.mode in {
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
}:
|
||||
return response.get("answer", "")
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
|
||||
return response.get("answer", "")
|
||||
case AppMode.WORKFLOW:
|
||||
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
case _:
|
||||
raise ValueError("Invalid app mode: " + str(app.mode))
|
||||
|
||||
|
||||
def convert_input_form_to_parameters(
|
||||
|
||||
@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
elif app.mode == AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
|
||||
raise ValueError("unexpected app type")
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
case AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
case AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
case _:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_chat_app(
|
||||
@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
case AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
case AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
case _:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
args={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"files": files,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@classmethod
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
|
||||
@ -961,36 +961,37 @@ class ProviderManager:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
match provider_quota.quota_type:
|
||||
case ProviderQuotaType.TRIAL if trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
case ProviderQuotaType.PAID if paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
else:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
case _:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
|
||||
@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.create_collection_if_not_exists(dimension)
|
||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||
return []
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self.analyticdb_vector.text_exists(id)
|
||||
|
||||
@ -123,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
else:
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
def _get_cursor(self) -> Iterator[Any]:
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
|
||||
)
|
||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
@ -106,14 +106,15 @@ class ChromaVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
results: QueryResult
|
||||
if document_ids_filter:
|
||||
results: QueryResult = collection.query(
|
||||
results = collection.query(
|
||||
query_embeddings=query_vector,
|
||||
n_results=kwargs.get("top_k", 4),
|
||||
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
|
||||
)
|
||||
else:
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
# Check if results contain data
|
||||
@ -165,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST or "",
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
|
||||
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
||||
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
||||
),
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
|
||||
@ -55,9 +55,8 @@ class PGVectoRS(BaseVector):
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self._client = create_engine(self._url)
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
session.commit()
|
||||
self._fields: list[str] = []
|
||||
|
||||
class _Table(CollectionORM):
|
||||
@ -88,7 +87,7 @@ class PGVectoRS(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
@ -111,12 +110,11 @@ class PGVectoRS(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
pks = []
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
for document, embedding in zip(documents, embeddings):
|
||||
pk = uuid4()
|
||||
session.execute(
|
||||
@ -128,7 +126,6 @@ class PGVectoRS(BaseVector):
|
||||
),
|
||||
)
|
||||
pks.append(pk)
|
||||
session.commit()
|
||||
|
||||
return pks
|
||||
|
||||
@ -145,10 +142,9 @@ class PGVectoRS(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self._client) as session:
|
||||
@ -159,15 +155,13 @@ class PGVectoRS(BaseVector):
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete(self):
|
||||
with Session(self._client) as session:
|
||||
with sessionmaker(bind=self._client).begin() as session:
|
||||
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@ -180,7 +179,7 @@ class QdrantVector(BaseVector):
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
@ -472,7 +471,7 @@ class QdrantVector(BaseVector):
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client._load()
|
||||
self._client._load() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
|
||||
@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
Base: Any = declarative_base()
|
||||
|
||||
|
||||
class RelytConfig(BaseModel):
|
||||
|
||||
@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
|
||||
FileType,
|
||||
detect_filetype,
|
||||
)
|
||||
|
||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
|
||||
@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -884,7 +884,7 @@ class DatasetRetrieval:
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
return
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Collect all document_ids and batch fetch DatasetDocuments
|
||||
document_ids = {
|
||||
doc.metadata["document_id"]
|
||||
@ -975,7 +975,6 @@ class DatasetRetrieval:
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
|
||||
|
||||
@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel):
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
match self.type:
|
||||
case ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"authentication", self.authentication.model_dump() if self.authentication else None
|
||||
)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
case ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
case _:
|
||||
pass
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -166,13 +167,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -190,13 +185,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1))
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
@ -210,13 +199,7 @@ class ToolFileManager:
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -234,13 +217,7 @@ class ToolFileManager:
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1))
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
@ -205,16 +205,160 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
builtin_tool = provider_controller.get_tool(tool_name)
|
||||
if not builtin_tool:
|
||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
if builtin_provider is None:
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
case ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@ -223,177 +367,28 @@ class ToolManager:
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
builtin_provider = None
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
# get specific credentials
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
# if the provider has been deleted, raise an error
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
# check if the credential is allowed to be used
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(
|
||||
credential_id=builtin_provider.id,
|
||||
provider=provider_id,
|
||||
credential_type=PluginCredentialType.TOOL,
|
||||
check_existence=False,
|
||||
)
|
||||
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
# decrypt the credentials
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = json.dumps(
|
||||
encrypter.encrypt(refreshed_credentials.credentials)
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
cache.delete()
|
||||
|
||||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=builtin_provider.credential_type,
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=api_provider,
|
||||
)
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
case ToolProviderType.PLUGIN:
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
case ToolProviderType.MCP:
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
case ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
case _:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
@ -1027,31 +1022,31 @@ class ToolManager:
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
case ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
elif provider_type == ToolProviderType.API:
|
||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
case _:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_parameters_type(
|
||||
|
||||
@ -7,14 +7,13 @@ from sqlalchemy import select
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
@ -17,18 +16,6 @@ from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
|
||||
@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
|
||||
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
|
||||
pattern = re.compile(
|
||||
r"""
|
||||
^
|
||||
(?:
|
||||
[\u2000-\u2025] # General Punctuation: spaces, quotes, dashes
|
||||
| [\u2027-\u206F] # General Punctuation: ellipsis, underscores, etc.
|
||||
| [\u2E00-\u2E7F] # Supplemental Punctuation: medieval, ancient marks
|
||||
| [\u3000-\u300F] # CJK Punctuation: 、。〃「」『》』 (excludes 【】)
|
||||
| [\u3012-\u303F] # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
|
||||
| ["#$%&'()*+,./:;<=>?@^_`~] # ASCII punctuation (excludes []【】)
|
||||
)+
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
return re.sub(pattern, "", text)
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
session.query(Workflow)
|
||||
workflow: Workflow | None = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
return self.tools
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
db_provider: WorkflowToolProvider | None = session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
|
||||
@ -305,14 +305,15 @@ class WorkflowTool(Tool):
|
||||
"transfer_method": file.transfer_method.value,
|
||||
"type": file.type.value,
|
||||
}
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
|
||||
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
match file.transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception:
|
||||
@ -357,8 +358,11 @@ class WorkflowTool(Tool):
|
||||
def _update_file_mapping(self, file_dict: dict):
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_id
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_id
|
||||
case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE:
|
||||
pass
|
||||
return file_dict
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||
|
||||
@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DifyGraphInitContext:
|
||||
"""Explicit graph-init values owned by the workflow layer.
|
||||
|
||||
Dify is gradually removing direct `GraphInitParams` construction from its
|
||||
production call sites. Keep the translation here until `graphon` exposes an
|
||||
equivalent explicit API.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
graph_config: Mapping[str, Any]
|
||||
run_context: Mapping[str, Any]
|
||||
call_depth: int
|
||||
|
||||
def to_graph_init_params(self) -> "GraphInitParams":
|
||||
from graphon.entities import GraphInitParams
|
||||
|
||||
return GraphInitParams(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
run_context=self.run_context,
|
||||
call_depth=self.call_depth,
|
||||
)
|
||||
|
||||
|
||||
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
|
||||
package = importlib.import_module(package_name)
|
||||
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
|
||||
@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
|
||||
Default implementation of NodeFactory that resolves node classes from the live registry.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_graph_init_context(
|
||||
cls,
|
||||
*,
|
||||
graph_init_context: DifyGraphInitContext,
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> "DifyNodeFactory":
|
||||
"""Bridge Dify's explicit init context into the current `graphon` API."""
|
||||
return cls(
|
||||
graph_init_params=graph_init_context.to_graph_init_params(),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
|
||||
@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
def post_init(self) -> None:
|
||||
from core.workflow.node_runtime import DifyFileReferenceFactory
|
||||
|
||||
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
|
||||
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -155,24 +155,25 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
file = files.get(param_name)
|
||||
if file and isinstance(file, dict):
|
||||
file_var = self.generate_file_var(param_name, file)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
match param_type:
|
||||
case SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
file = files.get(param_name)
|
||||
if file and isinstance(file, dict):
|
||||
file_var = self.generate_file_var(param_name, file)
|
||||
if file_var:
|
||||
outputs[param_name] = file_var
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
outputs[param_name] = files
|
||||
else:
|
||||
# Get regular body parameter
|
||||
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||
case _:
|
||||
# Get regular body parameter
|
||||
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||
|
||||
# Include raw webhook data for debugging/advanced use
|
||||
outputs["_webhook_raw"] = webhook_data
|
||||
|
||||
@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
|
||||
from core.workflow.node_factory import (
|
||||
DifyGraphInitContext,
|
||||
DifyNodeFactory,
|
||||
is_start_node_type,
|
||||
resolve_workflow_node_class,
|
||||
)
|
||||
from core.workflow.system_variables import (
|
||||
default_system_variables,
|
||||
get_node_creation_preload_selectors,
|
||||
@ -251,17 +256,18 @@ class WorkflowEntry:
|
||||
node_version = str(node_config_data.version)
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
# init graph context and runtime state
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
@ -313,8 +319,8 @@ class WorkflowEntry:
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
@ -409,17 +415,18 @@ class WorkflowEntry:
|
||||
variable_pool = VariablePool()
|
||||
add_variables_to_pool(variable_pool, default_system_variables())
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
# init graph context and runtime state
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id="",
|
||||
graph_config=graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
@ -430,8 +437,8 @@ class WorkflowEntry:
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
|
||||
@ -68,46 +68,49 @@ class EnterpriseMetricHandler:
|
||||
|
||||
# Route to appropriate handler based on case
|
||||
case = envelope.case
|
||||
if case == TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
elif case == TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
elif case == TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
elif case == TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
elif case == TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
elif case == TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
elif case == TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
elif case == TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
elif case == TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
elif case == TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
elif case == TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
match case:
|
||||
case TelemetryCase.APP_CREATED:
|
||||
self._on_app_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
|
||||
case TelemetryCase.APP_UPDATED:
|
||||
self._on_app_updated(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
|
||||
case TelemetryCase.APP_DELETED:
|
||||
self._on_app_deleted(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
|
||||
case TelemetryCase.FEEDBACK_CREATED:
|
||||
self._on_feedback_created(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
|
||||
case TelemetryCase.MESSAGE_RUN:
|
||||
self._on_message_run(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
|
||||
case TelemetryCase.TOOL_EXECUTION:
|
||||
self._on_tool_execution(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
|
||||
case TelemetryCase.MODERATION_CHECK:
|
||||
self._on_moderation_check(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
|
||||
case TelemetryCase.SUGGESTED_QUESTION:
|
||||
self._on_suggested_question(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
|
||||
case TelemetryCase.DATASET_RETRIEVAL:
|
||||
self._on_dataset_retrieval(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
|
||||
case TelemetryCase.GENERATE_NAME:
|
||||
self._on_generate_name(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
|
||||
case TelemetryCase.PROMPT_GENERATION:
|
||||
self._on_prompt_generation(envelope)
|
||||
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
|
||||
case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION:
|
||||
pass
|
||||
case _:
|
||||
logger.warning(
|
||||
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
|
||||
case,
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
)
|
||||
|
||||
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
|
||||
"""Check if this event has already been processed.
|
||||
|
||||
@ -7,10 +7,12 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
from redis.backoff import ExponentialWithJitterBackoff # type: ignore
|
||||
from redis.cache import CacheConfig
|
||||
from redis.client import PubSub
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.retry import Retry
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
@ -158,8 +160,41 @@ def _get_cache_configuration() -> CacheConfig | None:
|
||||
return CacheConfig()
|
||||
|
||||
|
||||
def _get_retry_policy() -> Retry:
|
||||
"""Build the shared retry policy for Redis connections."""
|
||||
return Retry(
|
||||
backoff=ExponentialWithJitterBackoff(
|
||||
base=dify_config.REDIS_RETRY_BACKOFF_BASE,
|
||||
cap=dify_config.REDIS_RETRY_BACKOFF_CAP,
|
||||
),
|
||||
retries=dify_config.REDIS_RETRY_RETRIES,
|
||||
)
|
||||
|
||||
|
||||
def _get_connection_health_params() -> dict[str, Any]:
|
||||
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
|
||||
return {
|
||||
"retry": _get_retry_policy(),
|
||||
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
|
||||
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
||||
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
||||
}
|
||||
|
||||
|
||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
"""Get retry and timeout parameters for Redis Cluster clients.
|
||||
|
||||
RedisCluster does not support ``health_check_interval`` as a constructor
|
||||
keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded
|
||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||
are passed through.
|
||||
"""
|
||||
params = _get_connection_health_params()
|
||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||
|
||||
|
||||
def _get_base_redis_params() -> dict[str, Any]:
|
||||
"""Get base Redis connection parameters."""
|
||||
"""Get base Redis connection parameters including retry and health policy."""
|
||||
return {
|
||||
"username": dify_config.REDIS_USERNAME,
|
||||
"password": dify_config.REDIS_PASSWORD or None,
|
||||
@ -169,6 +204,7 @@ def _get_base_redis_params() -> dict[str, Any]:
|
||||
"decode_responses": False,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
**_get_connection_health_params(),
|
||||
}
|
||||
|
||||
|
||||
@ -215,6 +251,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
||||
"password": dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
**_get_cluster_connection_health_params(),
|
||||
}
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
@ -226,7 +263,8 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
"""Create standalone Redis client."""
|
||||
connection_class, ssl_kwargs = _get_ssl_configuration()
|
||||
|
||||
redis_params.update(
|
||||
params = {**redis_params}
|
||||
params.update(
|
||||
{
|
||||
"host": dify_config.REDIS_HOST,
|
||||
"port": dify_config.REDIS_PORT,
|
||||
@ -235,28 +273,31 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
)
|
||||
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
|
||||
if ssl_kwargs:
|
||||
redis_params.update(ssl_kwargs)
|
||||
params.update(ssl_kwargs)
|
||||
|
||||
pool = redis.ConnectionPool(**redis_params)
|
||||
pool = redis.ConnectionPool(**params)
|
||||
client: redis.Redis = redis.Redis(connection_pool=pool)
|
||||
return client
|
||||
|
||||
|
||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
|
||||
max_conns = dify_config.REDIS_MAX_CONNECTIONS
|
||||
if use_clusters:
|
||||
if max_conns:
|
||||
return RedisCluster.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
|
||||
if use_clusters:
|
||||
health_params = _get_cluster_connection_health_params()
|
||||
kwargs: dict[str, Any] = {**health_params}
|
||||
if max_conns:
|
||||
kwargs["max_connections"] = max_conns
|
||||
return RedisCluster.from_url(pubsub_url, **kwargs)
|
||||
|
||||
health_params = _get_connection_health_params()
|
||||
kwargs = {**health_params}
|
||||
if max_conns:
|
||||
return redis.Redis.from_url(pubsub_url, max_connections=max_conns)
|
||||
else:
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
kwargs["max_connections"] = max_conns
|
||||
return redis.Redis.from_url(pubsub_url, **kwargs)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
|
||||
@ -813,56 +813,32 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dump_optional(value: Any) -> str | None:
|
||||
return json.dumps(value) if value else None
|
||||
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
self.suggested_questions = self._dump_optional(model_config.get("suggested_questions"))
|
||||
self.suggested_questions_after_answer = self._dump_optional(
|
||||
model_config.get("suggested_questions_after_answer")
|
||||
)
|
||||
self.speech_to_text = self._dump_optional(model_config.get("speech_to_text"))
|
||||
self.text_to_speech = self._dump_optional(model_config.get("text_to_speech"))
|
||||
self.more_like_this = self._dump_optional(model_config.get("more_like_this"))
|
||||
self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance"))
|
||||
self.external_data_tools = self._dump_optional(model_config.get("external_data_tools"))
|
||||
self.model = self._dump_optional(model_config.get("model"))
|
||||
self.user_input_form = self._dump_optional(model_config.get("user_input_form"))
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.agent_mode = self._dump_optional(model_config.get("agent_mode"))
|
||||
self.retriever_resource = self._dump_optional(model_config.get("retriever_resource"))
|
||||
self.prompt_type = PromptType(model_config.get("prompt_type", "simple"))
|
||||
self.chat_prompt_config = (
|
||||
json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None
|
||||
)
|
||||
self.completion_prompt_config = (
|
||||
json.dumps(model_config.get("completion_prompt_config"))
|
||||
if model_config.get("completion_prompt_config")
|
||||
else None
|
||||
)
|
||||
self.dataset_configs = (
|
||||
json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None
|
||||
)
|
||||
self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None
|
||||
self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config"))
|
||||
self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config"))
|
||||
self.dataset_configs = self._dump_optional(model_config.get("dataset_configs"))
|
||||
self.file_upload = self._dump_optional(model_config.get("file_upload"))
|
||||
return self
|
||||
|
||||
|
||||
@ -1632,52 +1608,53 @@ class Message(Base):
|
||||
|
||||
files: list[File] = []
|
||||
for message_file in message_files:
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if message_file.upload_file_id is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
match message_file.transfer_method:
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
if message_file.upload_file_id is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
if message_file.url is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
"url": message_file.url,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
if message_file.upload_file_id is None:
|
||||
assert message_file.url is not None
|
||||
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]
|
||||
mapping = {
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if message_file.url is None:
|
||||
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping={
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"upload_file_id": message_file.upload_file_id,
|
||||
"url": message_file.url,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
if message_file.upload_file_id is None:
|
||||
assert message_file.url is not None
|
||||
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]
|
||||
mapping = {
|
||||
"id": message_file.id,
|
||||
"type": message_file.type,
|
||||
"transfer_method": message_file.transfer_method,
|
||||
"tool_file_id": message_file.upload_file_id,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
|
||||
)
|
||||
"tool_file_id": message_file.upload_file_id,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=current_app.tenant_id,
|
||||
access_controller=_get_file_access_controller(),
|
||||
)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
raise ValueError(
|
||||
f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}"
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result = cast(
|
||||
|
||||
@ -1625,21 +1625,22 @@ class WorkflowDraftVariable(Base):
|
||||
# Rebuild them through the file factory so tenant ownership, signed URLs,
|
||||
# and storage-backed metadata come from canonical records instead of the
|
||||
# serialized JSON blob.
|
||||
if segment_type == SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
if segment_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
file_list = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type=segment_type, value=file_list)
|
||||
|
||||
return build_segment_with_type(segment_type=segment_type, value=value)
|
||||
match segment_type:
|
||||
case SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
file_list = self._rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type=segment_type, value=file_list)
|
||||
case _:
|
||||
return build_segment_with_type(segment_type=segment_type, value=value)
|
||||
|
||||
@staticmethod
|
||||
def rebuild_file_types(value: Any):
|
||||
@ -1672,21 +1673,22 @@ class WorkflowDraftVariable(Base):
|
||||
# Extends `variable_factory.build_segment_with_type` functionality by
|
||||
# reconstructing `FileSegment`` or `ArrayFileSegment`` objects from
|
||||
# their serialized dictionary or list representations, respectively.
|
||||
if segment_type == SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
if segment_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
file_list = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type=segment_type, value=file_list)
|
||||
|
||||
return build_segment_with_type(segment_type=segment_type, value=value)
|
||||
match segment_type:
|
||||
case SegmentType.FILE:
|
||||
if isinstance(value, File):
|
||||
return build_segment_with_type(segment_type, value)
|
||||
elif isinstance(value, dict):
|
||||
file = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type, file)
|
||||
else:
|
||||
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
|
||||
case SegmentType.ARRAY_FILE:
|
||||
if not isinstance(value, list):
|
||||
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
|
||||
file_list = cls.rebuild_file_types(value)
|
||||
return build_segment_with_type(segment_type=segment_type, value=file_list)
|
||||
case _:
|
||||
return build_segment_with_type(segment_type=segment_type, value=value)
|
||||
|
||||
def get_value(self) -> Segment:
|
||||
"""Decode the serialized value into its corresponding `Segment` object.
|
||||
|
||||
@ -40,7 +40,7 @@ dependencies = [
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
@ -171,7 +171,7 @@ dev = [
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.59.1",
|
||||
"pyrefly>=0.60.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
||||
@ -4,6 +4,7 @@ import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import click
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import app
|
||||
@ -113,11 +114,9 @@ def _delete_batch(
|
||||
try:
|
||||
with session.begin_nested():
|
||||
workflow_run_ids = [run.id for run in workflow_runs]
|
||||
message_data = (
|
||||
session.query(Message.id, Message.conversation_id)
|
||||
.where(Message.workflow_run_id.in_(workflow_run_ids))
|
||||
.all()
|
||||
)
|
||||
message_data = session.execute(
|
||||
select(Message.id, Message.conversation_id).where(Message.workflow_run_id.in_(workflow_run_ids))
|
||||
).all()
|
||||
message_id_list = [msg.id for msg in message_data]
|
||||
conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
|
||||
if message_id_list:
|
||||
@ -132,23 +131,19 @@ def _delete_batch(
|
||||
SavedMessage,
|
||||
]
|
||||
for model in message_related_models:
|
||||
session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore
|
||||
session.execute(delete(model).where(model.message_id.in_(message_id_list))) # type: ignore
|
||||
# error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib
|
||||
# and these 6 types all have the message_id field.
|
||||
|
||||
session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.execute(delete(Message).where(Message.workflow_run_id.in_(workflow_run_ids)))
|
||||
|
||||
if conversation_id_list:
|
||||
session.query(ConversationVariable).where(
|
||||
ConversationVariable.conversation_id.in_(conversation_id_list)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(ConversationVariable).where(ConversationVariable.conversation_id.in_(conversation_id_list))
|
||||
)
|
||||
|
||||
session.execute(delete(Conversation).where(Conversation.id.in_(conversation_id_list)))
|
||||
|
||||
def _delete_node_executions(active_session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_ids = [run.id for run in runs]
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
|
||||
@ -3,7 +3,6 @@ import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.6.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
@ -75,10 +63,6 @@ class Import(BaseModel):
|
||||
error: str = ""
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -120,7 +120,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
|
||||
app_ids = [app.id for app in apps]
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
messages = (
|
||||
session.query(Message)
|
||||
.where(
|
||||
@ -152,7 +152,6 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
cls._clear_message_related_tables(session, tenant_id, message_ids)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@ -161,7 +160,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
)
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
conversations = (
|
||||
session.query(Conversation)
|
||||
.where(
|
||||
@ -190,7 +189,6 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
session.query(Conversation).where(
|
||||
Conversation.id.in_(conversation_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
@ -294,7 +292,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
break
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
.where(
|
||||
@ -326,7 +324,6 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -53,13 +53,12 @@ class DatasourceProviderService:
|
||||
"""
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
@ -109,7 +108,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get credential by id
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
@ -156,7 +155,6 @@ class DatasourceProviderService:
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
@ -174,7 +172,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -224,7 +222,6 @@ class DatasourceProviderService:
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
session.commit()
|
||||
|
||||
return real_credentials_list
|
||||
|
||||
@ -234,7 +231,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource provider name
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
@ -266,7 +263,6 @@ class DatasourceProviderService:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
session.commit()
|
||||
return
|
||||
|
||||
def set_default_datasource_provider(
|
||||
@ -275,7 +271,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
set default datasource provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
@ -300,7 +296,6 @@ class DatasourceProviderService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
@ -315,7 +310,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
@ -349,7 +344,6 @@ class DatasourceProviderService:
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
@ -488,7 +482,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource oauth provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
target_provider = (
|
||||
@ -535,7 +529,6 @@ class DatasourceProviderService:
|
||||
target_provider.expires_at = expire_at
|
||||
target_provider.encrypted_credentials = credentials
|
||||
target_provider.avatar_url = avatar_url or target_provider.avatar_url
|
||||
session.commit()
|
||||
|
||||
def add_datasource_oauth_provider(
|
||||
self,
|
||||
@ -550,7 +543,7 @@ class DatasourceProviderService:
|
||||
add datasource oauth provider
|
||||
"""
|
||||
credential_type = CredentialType.OAUTH2
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
||||
with redis_client.lock(lock, timeout=60):
|
||||
db_provider_name = name
|
||||
@ -604,7 +597,6 @@ class DatasourceProviderService:
|
||||
expires_at=expire_at,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def add_datasource_api_key_provider(
|
||||
self,
|
||||
@ -623,7 +615,7 @@ class DatasourceProviderService:
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
db_provider_name = name or self.generate_next_datasource_provider_name(
|
||||
@ -670,7 +662,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
||||
"""
|
||||
@ -926,7 +917,7 @@ class DatasourceProviderService:
|
||||
update datasource credentials.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
@ -980,7 +971,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials[key] = value
|
||||
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
session.commit()
|
||||
|
||||
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
||||
"""
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
|
||||
from sqlalchemy import case, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
@ -24,7 +24,7 @@ class EndUserService:
|
||||
when an end-user ID is known.
|
||||
"""
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
return session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
@ -54,7 +54,7 @@ class EndUserService:
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
|
||||
# This single query approach is more efficient than separate queries
|
||||
end_user = session.scalar(
|
||||
@ -82,7 +82,6 @@ class EndUserService:
|
||||
user_id,
|
||||
)
|
||||
end_user.type = type
|
||||
session.commit()
|
||||
else:
|
||||
# Create new end user if none exists
|
||||
end_user = EndUser(
|
||||
@ -94,7 +93,6 @@ class EndUserService:
|
||||
external_user_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
@ -135,7 +133,7 @@ class EndUserService:
|
||||
if not unique_app_ids:
|
||||
return result
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Fetch existing end users for all target apps in a single query
|
||||
existing_end_users: list[EndUser] = list(
|
||||
session.scalars(
|
||||
@ -174,7 +172,6 @@ class EndUserService:
|
||||
)
|
||||
|
||||
session.add_all(new_end_users)
|
||||
session.commit()
|
||||
|
||||
for eu in new_end_users:
|
||||
result[eu.app_id] = eu
|
||||
|
||||
21
api/services/entities/dsl_entities.py
Normal file
21
api/services/entities/dsl_entities.py
Normal file
@ -0,0 +1,21 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
@ -1,4 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
@ -7,11 +8,11 @@ from models.account import TenantPluginAutoUpgradeStrategy
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -23,11 +24,11 @@ class PluginAutoUpgradeService:
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
@ -46,16 +47,15 @@ class PluginAutoUpgradeService:
|
||||
exist_strategy.exclude_plugins = exclude_plugins
|
||||
exist_strategy.include_plugins = include_plugins
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
@ -83,5 +83,4 @@ class PluginAutoUpgradeService:
|
||||
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exist_strategy.exclude_plugins = [plugin_id]
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
@ -7,8 +8,10 @@ from models.account import TenantPluginPermission
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
@ -16,9 +19,9 @@ class PluginPermissionService:
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
@ -30,5 +33,4 @@ class PluginPermissionService:
|
||||
permission.install_permission = install_permission
|
||||
permission.debug_permission = debug_permission
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@ -555,7 +555,7 @@ class RagPipelineService:
|
||||
workflow_node_execution.id
|
||||
)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -569,7 +569,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel):
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
@ -1325,7 +1324,7 @@ class RagPipelineService:
|
||||
# Convert node_execution to WorkflowNodeExecution after save
|
||||
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=pipeline.id,
|
||||
@ -1339,7 +1338,6 @@ class RagPipelineService:
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
)
|
||||
session.commit()
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution_db_model,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
|
||||
@ -5,7 +5,6 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
@ -21,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -38,6 +37,7 @@ from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.enums import CollectionBindingType, DatasetRuntimeMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
KnowledgeConfiguration,
|
||||
@ -54,18 +54,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.1.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RagPipelineImportInfo(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
@ -76,10 +64,6 @@ class RagPipelineImportInfo(BaseModel):
|
||||
dataset_id: str | None = None
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypedDict, cast
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
@ -369,7 +369,7 @@ class MessagesCleanService:
|
||||
batch_deleted_messages = 0
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
fetch_messages_start = time.monotonic()
|
||||
msg_stmt = (
|
||||
select(Message.id, Message.app_id, Message.created_at)
|
||||
@ -477,7 +477,7 @@ class MessagesCleanService:
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
delete_relations_start = time.monotonic()
|
||||
# Delete related records first
|
||||
self._batch_delete_message_relations(session, message_ids_to_delete)
|
||||
@ -489,9 +489,7 @@ class MessagesCleanService:
|
||||
delete_result = cast(CursorResult, session.execute(delete_stmt))
|
||||
messages_deleted = delete_result.rowcount
|
||||
delete_messages_ms = int((time.monotonic() - delete_messages_start) * 1000)
|
||||
commit_start = time.monotonic()
|
||||
session.commit()
|
||||
commit_ms = int((time.monotonic() - commit_start) * 1000)
|
||||
commit_ms = 0
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
batch_deleted_messages = messages_deleted
|
||||
|
||||
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -46,13 +46,12 @@ class BuiltinToolManageService:
|
||||
delete custom oauth client params
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -150,7 +149,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
@ -203,9 +202,7 @@ class BuiltinToolManageService:
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@ -222,7 +219,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
@ -281,9 +278,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
@ -379,7 +374,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
@ -393,7 +388,6 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -409,7 +403,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||
if target_provider is None:
|
||||
@ -422,7 +416,6 @@ class BuiltinToolManageService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -654,7 +647,7 @@ class BuiltinToolManageService:
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@ -690,7 +683,6 @@ class BuiltinToolManageService:
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -48,21 +48,25 @@ class ToolTransformService:
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
return ""
|
||||
match provider_type:
|
||||
case ToolProviderType.BUILT_IN:
|
||||
return str(url_prefix / "builtin" / provider_name / "icon")
|
||||
case ToolProviderType.API | ToolProviderType.WORKFLOW:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
parsed = emoji_icon_adapter.validate_json(icon)
|
||||
return {"background": parsed["background"], "content": parsed["content"]}
|
||||
return {"background": icon["background"], "content": icon["content"]}
|
||||
except (ValueError, ValidationError, KeyError):
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
case ToolProviderType.MCP:
|
||||
if isinstance(icon, Mapping):
|
||||
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
|
||||
return icon
|
||||
case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||
return ""
|
||||
case _:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
|
||||
@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic.
|
||||
import logging
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import AppTriggerStatus
|
||||
@ -34,13 +34,12 @@ class AppTriggerService:
|
||||
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.execute(
|
||||
update(AppTrigger)
|
||||
.where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED)
|
||||
.values(status=AppTriggerStatus.RATE_LIMITED)
|
||||
)
|
||||
session.commit()
|
||||
logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id)
|
||||
|
||||
@ -6,7 +6,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@ -146,7 +146,7 @@ class TriggerProviderService:
|
||||
"""
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
@ -205,7 +205,6 @@ class TriggerProviderService:
|
||||
subscription.id = subscription_id or str(uuid.uuid4())
|
||||
|
||||
session.add(subscription)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@ -241,7 +240,7 @@ class TriggerProviderService:
|
||||
:param expires_at: Optional new expiration timestamp
|
||||
:return: Success response with updated subscription info
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
@ -302,8 +301,6 @@ class TriggerProviderService:
|
||||
if expires_at is not None:
|
||||
subscription.expires_at = expires_at
|
||||
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
@ -404,7 +401,7 @@ class TriggerProviderService:
|
||||
:param subscription_id: Subscription instance ID
|
||||
:return: New token info
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
|
||||
if not subscription:
|
||||
@ -448,7 +445,6 @@ class TriggerProviderService:
|
||||
# Update credentials
|
||||
subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
subscription.credential_expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
# Clear cache
|
||||
cache.delete()
|
||||
@ -478,7 +474,7 @@ class TriggerProviderService:
|
||||
"""
|
||||
now_ts: int = int(now if now is not None else _time.time())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
@ -531,7 +527,6 @@ class TriggerProviderService:
|
||||
# Persist refreshed properties and expires_at
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
|
||||
subscription.expires_at = int(refreshed.expires_at)
|
||||
session.commit()
|
||||
properties_cache.delete()
|
||||
|
||||
logger.info(
|
||||
@ -639,7 +634,7 @@ class TriggerProviderService:
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Find existing custom client params
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
@ -683,8 +678,6 @@ class TriggerProviderService:
|
||||
if enabled is not None:
|
||||
custom_client.enabled = enabled
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
@ -733,13 +726,12 @@ class TriggerProviderService:
|
||||
:param provider_id: Provider identifier
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(TriggerOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from flask import Request, Response
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
|
||||
@ -215,7 +215,7 @@ class TriggerService:
|
||||
not_found_in_cache.append(node_info)
|
||||
continue
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
try:
|
||||
# lock the concurrent plugin trigger creation
|
||||
redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
|
||||
@ -260,7 +260,6 @@ class TriggerService:
|
||||
cache.model_dump_json(),
|
||||
ex=60 * 60,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Update existing records if subscription_id changed
|
||||
for node_info in nodes_in_graph:
|
||||
@ -290,14 +289,12 @@ class TriggerService:
|
||||
cache.model_dump_json(),
|
||||
ex=60 * 60,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# delete the nodes not found in the graph
|
||||
for node_id in nodes_id_in_db:
|
||||
if node_id not in nodes_id_in_graph:
|
||||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
|
||||
raise
|
||||
|
||||
@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod
|
||||
from graphon.variables.types import ArrayValidation, SegmentType
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
@ -598,21 +598,38 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to the specified type
|
||||
"""
|
||||
if param_type == SegmentType.STRING:
|
||||
return value
|
||||
elif param_type == SegmentType.NUMBER:
|
||||
if not cls._can_convert_to_number(value):
|
||||
raise ValueError(f"Cannot convert '{value}' to number")
|
||||
numeric_value = float(value)
|
||||
return int(numeric_value) if numeric_value.is_integer() else numeric_value
|
||||
elif param_type == SegmentType.BOOLEAN:
|
||||
lower_value = value.lower()
|
||||
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
|
||||
if lower_value not in bool_map:
|
||||
raise ValueError(f"Cannot convert '{value}' to boolean")
|
||||
return bool_map[lower_value]
|
||||
else:
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
match param_type:
|
||||
case SegmentType.STRING:
|
||||
return value
|
||||
case SegmentType.NUMBER:
|
||||
if not cls._can_convert_to_number(value):
|
||||
raise ValueError(f"Cannot convert '{value}' to number")
|
||||
numeric_value = float(value)
|
||||
return int(numeric_value) if numeric_value.is_integer() else numeric_value
|
||||
case SegmentType.BOOLEAN:
|
||||
lower_value = value.lower()
|
||||
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
|
||||
if lower_value not in bool_map:
|
||||
raise ValueError(f"Cannot convert '{value}' to boolean")
|
||||
return bool_map[lower_value]
|
||||
case (
|
||||
SegmentType.OBJECT
|
||||
| SegmentType.FILE
|
||||
| SegmentType.ARRAY_ANY
|
||||
| SegmentType.ARRAY_STRING
|
||||
| SegmentType.ARRAY_NUMBER
|
||||
| SegmentType.ARRAY_OBJECT
|
||||
| SegmentType.ARRAY_FILE
|
||||
| SegmentType.ARRAY_BOOLEAN
|
||||
| SegmentType.SECRET
|
||||
| SegmentType.INTEGER
|
||||
| SegmentType.FLOAT
|
||||
| SegmentType.NONE
|
||||
| SegmentType.GROUP
|
||||
):
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
case _:
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
|
||||
@classmethod
|
||||
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
|
||||
@ -918,7 +935,7 @@ class WebhookService:
|
||||
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
|
||||
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# fetch the non-cached nodes from DB
|
||||
all_records = session.scalars(
|
||||
select(WorkflowWebhookTrigger).where(
|
||||
@ -947,14 +964,12 @@ class WebhookService:
|
||||
redis_client.set(
|
||||
f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# delete the nodes not found in the graph
|
||||
for node_id in nodes_id_in_db:
|
||||
if node_id not in nodes_id_in_graph:
|
||||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
|
||||
@ -1075,9 +1075,8 @@ class DraftVariableSaver:
|
||||
)
|
||||
engine = bind = self._session.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
|
||||
session.add(variable_file)
|
||||
session.commit()
|
||||
|
||||
return truncation_result.result, variable_file
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import (
|
||||
@ -48,7 +48,12 @@ from core.workflow.human_input_compat import (
|
||||
normalize_human_input_node_data_for_graph,
|
||||
parse_human_input_delivery_methods,
|
||||
)
|
||||
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type
|
||||
from core.workflow.node_factory import (
|
||||
LATEST_VERSION,
|
||||
DifyGraphInitContext,
|
||||
get_node_type_classes_mapping,
|
||||
is_start_node_type,
|
||||
)
|
||||
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
@ -837,7 +842,7 @@ class WorkflowService:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -848,7 +853,6 @@ class WorkflowService:
|
||||
user=account,
|
||||
)
|
||||
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
|
||||
session.commit()
|
||||
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution,
|
||||
@ -977,7 +981,7 @@ class WorkflowService:
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=app_model.id,
|
||||
@ -988,7 +992,6 @@ class WorkflowService:
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
)
|
||||
draft_var_saver.save(outputs=outputs, process_data={})
|
||||
session.commit()
|
||||
|
||||
return outputs
|
||||
|
||||
@ -1134,18 +1137,20 @@ class WorkflowService:
|
||||
node_config: NodeConfigDict,
|
||||
variable_pool: VariablePool,
|
||||
) -> HumanInputNode:
|
||||
graph_init_params = GraphInitParams(
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=account.id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=account.id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_init_params = graph_init_context.to_graph_init_params()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
@ -1155,7 +1160,7 @@ class WorkflowService:
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete, select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
start_at = time.perf_counter()
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
|
||||
dataset_document = session.scalar(
|
||||
select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1)
|
||||
)
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
|
||||
return
|
||||
@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
if not dataset:
|
||||
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
|
||||
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||
|
||||
# delete auto disable log
|
||||
session.query(DatasetAutoDisableLog).where(
|
||||
DatasetAutoDisableLog.document_id == dataset_document.id
|
||||
).delete()
|
||||
session.execute(
|
||||
delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id)
|
||||
)
|
||||
|
||||
# update segment to enable
|
||||
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
|
||||
{
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.disabled_at: None,
|
||||
DocumentSegment.disabled_by: None,
|
||||
DocumentSegment.updated_at: naive_utc_now(),
|
||||
}
|
||||
session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now())
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
||||
# ============ Step 3: Delete metadata binding (separate short transaction) ============
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
deleted_count = int(
|
||||
session.query(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
result = cast(
|
||||
CursorResult,
|
||||
session.execute(
|
||||
delete(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
)
|
||||
),
|
||||
)
|
||||
deleted_count = result.rowcount
|
||||
session.commit()
|
||||
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
|
||||
except Exception:
|
||||
|
||||
@ -112,7 +112,9 @@ def clean_dataset_task(
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
@ -150,20 +152,22 @@ def clean_dataset_task(
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
|
||||
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
|
||||
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
# delete dataset metadata
|
||||
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
|
||||
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
|
||||
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
|
||||
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
|
||||
# delete pipeline and workflow
|
||||
if pipeline_id:
|
||||
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
|
||||
session.query(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
).delete()
|
||||
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
|
||||
session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == tenant_id,
|
||||
Workflow.app_id == pipeline_id,
|
||||
Workflow.type == WorkflowType.RAG_PIPELINE,
|
||||
)
|
||||
)
|
||||
# delete files
|
||||
if documents:
|
||||
file_ids = []
|
||||
@ -174,7 +178,7 @@ def clean_dataset_task(
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_ids.append(file_id)
|
||||
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
|
||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
storage.delete(file.key)
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
if index_node_ids:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
session.execute(
|
||||
delete(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
)
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id.in_(dataset_documents_ids))
|
||||
.values(indexing_status="indexing")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="completed")
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
session.execute(
|
||||
update(DatasetDocument)
|
||||
.where(DatasetDocument.id == dataset_document.id)
|
||||
.values(indexing_status="error", error=str(e))
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import ConversationVariable
|
||||
@ -29,29 +30,21 @@ def delete_conversation_related_data(conversation_id: str):
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(delete(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id))
|
||||
|
||||
session.execute(delete(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id))
|
||||
|
||||
session.execute(
|
||||
delete(ToolConversationVariables).where(ToolConversationVariables.conversation_id == conversation_id)
|
||||
)
|
||||
|
||||
session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.execute(delete(ToolFile).where(ToolFile.conversation_id == conversation_id))
|
||||
|
||||
session.query(ToolConversationVariables).where(
|
||||
ToolConversationVariables.conversation_id == conversation_id
|
||||
).delete(synchronize_session=False)
|
||||
session.execute(delete(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id))
|
||||
|
||||
session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
|
||||
session.execute(delete(Message).where(Message.conversation_id == conversation_id))
|
||||
|
||||
session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
|
||||
|
||||
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.execute(delete(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id))
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
@ -29,12 +29,12 @@ def delete_segment_from_index_task(
|
||||
start_at = time.perf_counter()
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
|
||||
return
|
||||
|
||||
dataset_document = session.query(Document).where(Document.id == document_id).first()
|
||||
dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if not dataset_document:
|
||||
return
|
||||
|
||||
@ -60,11 +60,9 @@ def delete_segment_from_index_task(
|
||||
)
|
||||
if dataset.is_multimodal:
|
||||
# delete segment attachment binding
|
||||
segment_attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
)
|
||||
segment_attachment_bindings = session.scalars(
|
||||
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
).all()
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
|
||||
@ -77,7 +75,7 @@ def delete_segment_from_index_task(
|
||||
session.execute(segment_attachment_bind_delete_stmt)
|
||||
|
||||
# delete upload file
|
||||
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
|
||||
session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids)))
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
||||
start_at = time.perf_counter()
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1))
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if dataset.is_multimodal:
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
)
|
||||
segment_attachment_bindings = session.scalars(
|
||||
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
).all()
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_node_ids.extend(attachment_ids)
|
||||
@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
||||
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
# update segment error msg
|
||||
session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"enabled": True,
|
||||
}
|
||||
session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
.values(disabled_at=None, disabled_by=None, enabled=True)
|
||||
)
|
||||
session.commit()
|
||||
finally:
|
||||
|
||||
@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
tenant_id = None
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
|
||||
@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.exception("Failed to clean vector index for document %s", document_id)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if not document:
|
||||
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
|
||||
return
|
||||
@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
except Exception as e:
|
||||
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
|
||||
@ -47,7 +47,7 @@ def regenerate_summary_index_task(
|
||||
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
@ -84,8 +84,8 @@ def regenerate_summary_index_task(
|
||||
# For embedding_model change: directly query all segments with existing summaries
|
||||
# Don't require document indexing_status == "completed"
|
||||
# Include summaries with status "completed" or "error" (if they have content)
|
||||
segments_with_summaries = (
|
||||
session.query(DocumentSegment, DocumentSegmentSummary)
|
||||
segments_with_summaries = session.execute(
|
||||
select(DocumentSegment, DocumentSegmentSummary)
|
||||
.join(
|
||||
DocumentSegmentSummary,
|
||||
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
|
||||
@ -110,8 +110,7 @@ def regenerate_summary_index_task(
|
||||
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments_with_summaries:
|
||||
logger.info(
|
||||
@ -215,8 +214,8 @@ def regenerate_summary_index_task(
|
||||
|
||||
try:
|
||||
# Get all segments with existing summaries
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment)
|
||||
.join(
|
||||
DocumentSegmentSummary,
|
||||
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
|
||||
@ -229,8 +228,7 @@ def regenerate_summary_index_task(
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments:
|
||||
continue
|
||||
@ -245,13 +243,13 @@ def regenerate_summary_index_task(
|
||||
summary_record = None
|
||||
try:
|
||||
# Get existing summary record
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
chunk_id=segment.id,
|
||||
dataset_id=dataset_id,
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record:
|
||||
|
||||
@ -0,0 +1,388 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services import recommended_app_service as service_module
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _apps_response(
|
||||
recommended_apps: list[dict] | None = None,
|
||||
categories: list[str] | None = None,
|
||||
) -> dict:
|
||||
if recommended_apps is None:
|
||||
recommended_apps = [
|
||||
{"id": "app-1", "name": "Test App 1", "description": "d1", "category": "productivity"},
|
||||
{"id": "app-2", "name": "Test App 2", "description": "d2", "category": "communication"},
|
||||
]
|
||||
if categories is None:
|
||||
categories = ["productivity", "communication", "utilities"]
|
||||
return {"recommended_apps": recommended_apps, "categories": categories}
|
||||
|
||||
|
||||
def _app_detail(
|
||||
app_id: str = "app-123",
|
||||
name: str = "Test App",
|
||||
description: str = "Test description",
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
detail: dict[str, Any] = {
|
||||
"id": app_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": kwargs.get("category", "productivity"),
|
||||
"icon": kwargs.get("icon", "🚀"),
|
||||
"model_config": kwargs.get("model_config", {}),
|
||||
}
|
||||
detail.update(kwargs)
|
||||
return detail
|
||||
|
||||
|
||||
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
return cast("dict[str, Any] | None", result)
|
||||
|
||||
|
||||
def _mock_factory_for_apps(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
mode: str,
|
||||
result: dict[str, Any],
|
||||
fallback_result: dict[str, Any] | None = None,
|
||||
) -> tuple[MagicMock, MagicMock]:
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommended_apps_and_categories.return_value = result
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
builtin_instance = MagicMock()
|
||||
if fallback_result is not None:
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_buildin_recommend_app_retrieval",
|
||||
MagicMock(return_value=builtin_instance),
|
||||
)
|
||||
return retrieval_instance, builtin_instance
|
||||
|
||||
|
||||
# ── Pure logic tests: get_recommended_apps_and_categories ──────────────
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_success_with_apps(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
expected = _apps_response()
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = expected
|
||||
mock_factory = MagicMock(return_value=mock_instance)
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
assert result == expected
|
||||
assert len(result["recommended_apps"]) == 2
|
||||
assert len(result["categories"]) == 3
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
empty_response = {"recommended_apps": [], "categories": []}
|
||||
builtin_response = _apps_response(
|
||||
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
|
||||
)
|
||||
|
||||
mock_remote_instance = MagicMock()
|
||||
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_remote_instance)
|
||||
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
|
||||
|
||||
assert result == builtin_response
|
||||
assert result["recommended_apps"][0]["id"] == "builtin-1"
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
|
||||
none_response = {"recommended_apps": None, "categories": ["test"]}
|
||||
builtin_response = _apps_response()
|
||||
|
||||
mock_db_instance = MagicMock()
|
||||
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_db_instance)
|
||||
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_different_languages(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
|
||||
for language in ["en-US", "zh-CN", "ja-JP", "fr-FR"]:
|
||||
lang_response = _apps_response(
|
||||
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
|
||||
)
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(language)
|
||||
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_uses_correct_factory_mode(self, mock_config, mock_factory_class):
|
||||
for mode in ["remote", "builtin", "db"]:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
response = _apps_response()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
|
||||
# ── Pure logic tests: get_recommend_app_detail ─────────────────────────
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_success(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app")
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-123"))
|
||||
|
||||
assert result == expected
|
||||
assert result["id"] == "app-123"
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with("app-123")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_different_modes(self, mock_config, mock_factory_class):
|
||||
for mode in ["remote", "builtin", "db"]:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
detail = _app_detail(app_id="test-app", name=f"App from {mode}")
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = detail
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("test-app"))
|
||||
|
||||
assert result["name"] == f"App from {mode}"
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_returns_none_when_not_found(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = None
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("nonexistent"))
|
||||
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_returns_empty_dict(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = {}
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-empty"))
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_complex_model_config(self, mock_config, mock_factory_class):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
complex_config = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"parameters": {"temperature": 0.7, "max_tokens": 2000, "top_p": 1.0},
|
||||
}
|
||||
expected = _app_detail(
|
||||
app_id="complex-app",
|
||||
name="Complex App",
|
||||
model_config=complex_config,
|
||||
workflows=["workflow-1", "workflow-2"],
|
||||
tools=["tool-1", "tool-2", "tool-3"],
|
||||
)
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("complex-app"))
|
||||
|
||||
assert result["model_config"] == complex_config
|
||||
assert len(result["workflows"]) == 2
|
||||
assert len(result["tools"]) == 3
|
||||
|
||||
|
||||
# ── Integration tests: trial app features (real DB) ────────────────────
|
||||
|
||||
|
||||
class TestRecommendedAppServiceTrialFeatures:
|
||||
def test_get_apps_should_not_query_trial_table_when_disabled(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
|
||||
retrieval_instance, builtin_instance = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
|
||||
)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
assert result == expected
|
||||
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
|
||||
|
||||
def test_get_apps_should_enrich_can_trial_when_enabled(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app_id_1 = str(uuid.uuid4())
|
||||
app_id_2 = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
# app_id_1 has a TrialApp record; app_id_2 does not
|
||||
db_session_with_containers.add(TrialApp(app_id=app_id_1, tenant_id=tenant_id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
remote_result = {"recommended_apps": [], "categories": []}
|
||||
fallback_result = {
|
||||
"recommended_apps": [{"app_id": app_id_1}, {"app_id": app_id_2}],
|
||||
"categories": ["all"],
|
||||
}
|
||||
_, builtin_instance = _mock_factory_for_apps(
|
||||
monkeypatch, mode="remote", result=remote_result, fallback_result=fallback_result
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
|
||||
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
assert result["recommended_apps"][0]["can_trial"] is True
|
||||
assert result["recommended_apps"][1]["can_trial"] is False
|
||||
|
||||
@pytest.mark.parametrize("has_trial_app", [True, False])
|
||||
def test_get_detail_should_set_can_trial_when_enabled(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
has_trial_app: bool,
|
||||
):
|
||||
app_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
if has_trial_app:
|
||||
db_session_with_containers.add(TrialApp(app_id=app_id, tenant_id=tenant_id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
detail = {"id": app_id, "name": "Test App"}
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommend_app_detail.return_value = detail
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
|
||||
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
assert result["id"] == app_id
|
||||
assert result["can_trial"] is has_trial_app
|
||||
|
||||
def test_add_trial_app_record_increments_count_for_existing(self, db_session_with_containers: Session):
|
||||
app_id = str(uuid.uuid4())
|
||||
account_id = str(uuid.uuid4())
|
||||
|
||||
db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, account_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
record = db_session_with_containers.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||
.limit(1)
|
||||
)
|
||||
assert record is not None
|
||||
assert record.count == 4
|
||||
|
||||
def test_add_trial_app_record_creates_new_record(self, db_session_with_containers: Session):
|
||||
app_id = str(uuid.uuid4())
|
||||
account_id = str(uuid.uuid4())
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, account_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
record = db_session_with_containers.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||
.limit(1)
|
||||
)
|
||||
assert record is not None
|
||||
assert record.app_id == app_id
|
||||
assert record.account_id == account_id
|
||||
assert record.count == 1
|
||||
@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
# Note: Since we're mocking ConversationVariable.from_variable,
|
||||
# we can't directly check the id, but we can verify add_all was called
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_no_variables_creates_all(self):
|
||||
"""Test that all conversation variables are created when none exist in DB."""
|
||||
@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
# Verify that all variables were created
|
||||
assert len(added_items) == 2, "Should have added both variables"
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_all_variables_exist_no_changes(self):
|
||||
"""Test that no changes are made when all variables already exist in DB."""
|
||||
@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Verify that no variables were added
|
||||
assert not mock_session.add_all.called, "Session add_all should not have been called"
|
||||
assert mock_session.commit.called, "Session commit should still be called"
|
||||
|
||||
@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
sessionmaker=MagicMock(
|
||||
return_value=MagicMock(
|
||||
begin=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
|
||||
__exit__=lambda *a, **k: False,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
RedisChannel=MagicMock(),
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus
|
||||
@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
|
||||
def test_database_session_rolls_back_on_error(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"commit": 0, "rollback": 0}
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
calls = {"enter": 0, "exit_exc": None}
|
||||
|
||||
class _BeginContext:
|
||||
def __enter__(self):
|
||||
return self
|
||||
calls["enter"] += 1
|
||||
return MagicMock()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
calls["exit_exc"] = exc_type
|
||||
return False
|
||||
|
||||
def commit(self):
|
||||
calls["commit"] += 1
|
||||
class _Sessionmaker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
calls["rollback"] += 1
|
||||
def begin(self):
|
||||
return _BeginContext()
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
with pytest.raises(RuntimeError, match="db error"):
|
||||
with pipeline._database_session():
|
||||
raise RuntimeError("db error")
|
||||
|
||||
assert calls["commit"] == 0
|
||||
assert calls["rollback"] == 1
|
||||
assert calls["enter"] == 1
|
||||
assert calls["exit_exc"] is RuntimeError
|
||||
|
||||
def test_node_retry_and_started_handlers_cover_none_and_value(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation():
|
||||
assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value
|
||||
vector.delete()
|
||||
|
||||
runner._create_collection_if_not_exists.assert_called_once_with(2)
|
||||
runner.create_collection_if_not_exists.assert_called_once_with(2)
|
||||
runner.add_texts.assert_any_call(texts, [[0.1, 0.2]])
|
||||
runner.delete_by_ids.assert_called_once_with(["d1"])
|
||||
runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1")
|
||||
|
||||
@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
vector._client = MagicMock()
|
||||
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404)
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.create_collection.assert_called_once()
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
vector.config = _config()
|
||||
vector._client = MagicMock()
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=1024)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=1024)
|
||||
|
||||
vector._client.describe_collection.assert_not_called()
|
||||
vector._client.create_collection.assert_not_called()
|
||||
@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500)
|
||||
|
||||
with pytest.raises(ValueError, match="failed to create collection collection_1"):
|
||||
vector._create_collection_if_not_exists(embedding_dimension=512)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=512)
|
||||
|
||||
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
|
||||
@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
|
||||
|
||||
vector._get_cursor = _cursor_context
|
||||
|
||||
vector._create_collection_if_not_exists(embedding_dimension=3)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=3)
|
||||
|
||||
assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list)
|
||||
assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list)
|
||||
@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat
|
||||
vector._get_cursor = _cursor_context
|
||||
|
||||
with pytest.raises(RuntimeError, match="permission denied"):
|
||||
vector._create_collection_if_not_exists(embedding_dimension=3)
|
||||
vector.create_collection_if_not_exists(embedding_dimension=3)
|
||||
|
||||
|
||||
def test_delete_methods_raise_when_error_is_not_missing_table():
|
||||
|
||||
@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None):
|
||||
return _session
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _sessionmaker_factory(calls, execute_results=None):
|
||||
def _sessionmaker(*args, **kwargs):
|
||||
session = _FakeSessionContext(calls=calls, execute_results=execute_results)
|
||||
return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
|
||||
return _sessionmaker
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
"""Patch both Session and sessionmaker on the module with the same call tracker."""
|
||||
monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results))
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
vector.create_collection = MagicMock()
|
||||
@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(session_calls))
|
||||
_patch_both(monkeypatch, module, session_calls)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])]
|
||||
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results)))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results))
|
||||
|
||||
class _InsertBuilder:
|
||||
def __init__(self, table):
|
||||
@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
"Session",
|
||||
_session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
],
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls))
|
||||
vector.delete_by_ids(["doc-1"])
|
||||
assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
runtime_calls.clear()
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()])
|
||||
vector.delete()
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
monkeypatch.setattr(module, "Session", _session_factory(init_calls))
|
||||
_patch_both(monkeypatch, module, init_calls)
|
||||
vector = module.PGVectoRS("collection_1", _config(module), dim=3)
|
||||
|
||||
runtime_calls = []
|
||||
@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
(SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1),
|
||||
(SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8),
|
||||
]
|
||||
monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows]))
|
||||
_patch_both(monkeypatch, module, runtime_calls, execute_results=[rows])
|
||||
|
||||
docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"])
|
||||
assert len(docs) == 1
|
||||
|
||||
@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage:
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
sessionmaker_ctx = MagicMock()
|
||||
sessionmaker_ctx.begin.return_value = session_ctx
|
||||
|
||||
with (
|
||||
patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())),
|
||||
patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx),
|
||||
patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx),
|
||||
patch.object(retrieval, "_send_trace_task") as mock_trace,
|
||||
):
|
||||
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
|
||||
|
||||
query.update.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
mock_trace.assert_called_once()
|
||||
|
||||
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:
|
||||
|
||||
@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None:
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain")
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = None
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None:
|
||||
manager = ToolFileManager()
|
||||
message_file = SimpleNamespace(url=None)
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = None
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, None]
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None:
|
||||
message_file = SimpleNamespace(url="https://x/files/tools/tool123.png")
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
session = Mock()
|
||||
first_query = Mock()
|
||||
second_query = Mock()
|
||||
first_query.where.return_value.first.return_value = message_file
|
||||
second_query.where.return_value.first.return_value = tool_file
|
||||
session.query.side_effect = [first_query, second_query]
|
||||
session.scalar.side_effect = [message_file, tool_file]
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
with _patch_session_factory(session):
|
||||
@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
size=12,
|
||||
)
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
session.scalar.return_value = tool_file
|
||||
|
||||
# Act
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
|
||||
@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
@ -136,7 +136,7 @@ def test_from_db_builds_controller():
|
||||
parameter_configurations=[],
|
||||
)
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
fake_cm = MagicMock()
|
||||
fake_cm.__enter__.return_value = session
|
||||
@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
|
||||
assert controller.get_tools("tenant-1") == []
|
||||
@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing():
|
||||
mock_db.engine = object()
|
||||
with patch("core.tools.workflow_as_tool.provider.Session") as session_cls:
|
||||
session = _mock_session_with_begin()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.return_value = None
|
||||
session_cls.return_value.__enter__.return_value = session
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@ -110,6 +110,34 @@ class TestFetchMemory:
|
||||
)
|
||||
|
||||
|
||||
class TestDifyGraphInitContext:
|
||||
def test_to_graph_init_params_preserves_explicit_values(self):
|
||||
run_context = {
|
||||
DIFY_RUN_CONTEXT_KEY: DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
"extra": "value",
|
||||
}
|
||||
graph_config = {"nodes": [], "edges": []}
|
||||
graph_init_context = node_factory.DifyGraphInitContext(
|
||||
workflow_id="workflow-id",
|
||||
graph_config=graph_config,
|
||||
run_context=run_context,
|
||||
call_depth=2,
|
||||
)
|
||||
|
||||
result = graph_init_context.to_graph_init_params()
|
||||
|
||||
assert result.workflow_id == "workflow-id"
|
||||
assert result.graph_config == graph_config
|
||||
assert result.run_context == run_context
|
||||
assert result.call_depth == 2
|
||||
|
||||
|
||||
class TestDefaultWorkflowCodeExecutor:
|
||||
def test_execute_delegates_to_code_executor(self, monkeypatch):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_from_graph_init_context_translates_before_init(self):
|
||||
graph_init_context = MagicMock()
|
||||
graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params
|
||||
|
||||
with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
|
||||
factory = node_factory.DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(factory, node_factory.DifyNodeFactory)
|
||||
graph_init_context.to_graph_init_params.assert_called_once_with()
|
||||
init.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
graph_runtime_state = sentinel.graph_runtime_state
|
||||
|
||||
@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
|
||||
patch.object(
|
||||
workflow_entry,
|
||||
"GraphRuntimeState",
|
||||
@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool"),
|
||||
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
|
||||
patch.object(
|
||||
@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
|
||||
raise NotImplementedError
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
|
||||
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
|
||||
return {"question": ["node", "question"]}
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
|
||||
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
|
||||
return "1"
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
|
||||
patch.object(workflow_entry, "add_node_inputs_to_pool"),
|
||||
patch.object(workflow_entry, "load_into_variable_pool"),
|
||||
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
|
||||
@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
|
||||
patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
|
||||
patch.object(
|
||||
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
|
||||
) as graph_init_params,
|
||||
workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
|
||||
) as graph_init_context_cls,
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(
|
||||
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
|
||||
) as build_dify_run_context,
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
|
||||
patch.object(
|
||||
workflow_entry.DifyNodeFactory,
|
||||
"from_graph_init_context",
|
||||
return_value=dify_node_factory,
|
||||
) as dify_node_factory_cls,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_params.assert_called_once_with(
|
||||
graph_init_context_cls.assert_called_once_with(
|
||||
workflow_id="",
|
||||
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
|
||||
"node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
|
||||
@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
|
||||
call_depth=0,
|
||||
)
|
||||
dify_node_factory_cls.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_init_context=sentinel.graph_init_context,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
mapping_user_inputs_to_variable_pool.assert_called_once_with(
|
||||
@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
|
||||
patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
|
||||
patch.object(workflow_entry, "add_variables_to_pool"),
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
|
||||
patch.object(
|
||||
workflow_entry.DifyNodeFactory,
|
||||
"from_graph_init_context",
|
||||
return_value=dify_node_factory,
|
||||
),
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
|
||||
@ -1,53 +1,125 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from redis import RedisError
|
||||
from redis.retry import Retry
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
from extensions.ext_redis import (
|
||||
_get_base_redis_params,
|
||||
_get_cluster_connection_health_params,
|
||||
_get_connection_health_params,
|
||||
redis_fallback,
|
||||
)
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
class TestGetConnectionHealthParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_includes_all_health_params(self, mock_config):
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() == "success"
|
||||
params = _get_connection_health_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert "socket_timeout" in params
|
||||
assert "socket_connect_timeout" in params
|
||||
assert "health_check_interval" in params
|
||||
assert isinstance(params["retry"], Retry)
|
||||
assert params["retry"]._retries == 3
|
||||
assert params["socket_timeout"] == 5.0
|
||||
assert params["socket_connect_timeout"] == 5.0
|
||||
assert params["health_check_interval"] == 30
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
class TestGetClusterConnectionHealthParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_excludes_health_check_interval(self, mock_config):
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() == "fallback"
|
||||
params = _get_cluster_connection_health_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert "socket_timeout" in params
|
||||
assert "socket_connect_timeout" in params
|
||||
assert "health_check_interval" not in params
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
class TestGetBaseRedisParams:
|
||||
@patch("extensions.ext_redis.dify_config")
|
||||
def test_includes_retry_and_health_params(self, mock_config):
|
||||
mock_config.REDIS_USERNAME = None
|
||||
mock_config.REDIS_PASSWORD = None
|
||||
mock_config.REDIS_DB = 0
|
||||
mock_config.REDIS_SERIALIZATION_PROTOCOL = 3
|
||||
mock_config.REDIS_ENABLE_CLIENT_SIDE_CACHE = False
|
||||
mock_config.REDIS_RETRY_RETRIES = 3
|
||||
mock_config.REDIS_RETRY_BACKOFF_BASE = 1.0
|
||||
mock_config.REDIS_RETRY_BACKOFF_CAP = 10.0
|
||||
mock_config.REDIS_SOCKET_TIMEOUT = 5.0
|
||||
mock_config.REDIS_SOCKET_CONNECT_TIMEOUT = 5.0
|
||||
mock_config.REDIS_HEALTH_CHECK_INTERVAL = 30
|
||||
|
||||
assert test_func() is None
|
||||
params = _get_base_redis_params()
|
||||
|
||||
assert "retry" in params
|
||||
assert isinstance(params["retry"], Retry)
|
||||
assert params["socket_timeout"] == 5.0
|
||||
assert params["socket_connect_timeout"] == 5.0
|
||||
assert params["health_check_interval"] == 30
|
||||
# Existing params still present
|
||||
assert params["db"] == 0
|
||||
assert params["encoding"] == "utf-8"
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
class TestRedisFallback:
|
||||
def test_redis_fallback_success(self):
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
assert test_func() == "success"
|
||||
|
||||
def test_redis_fallback_error(self):
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
assert test_func() == "fallback"
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
def test_redis_fallback_none_default(self):
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
def test_redis_fallback_with_args(self):
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
def test_redis_fallback_with_kwargs(self):
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata(self):
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
|
||||
@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
session.scalar.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -31,7 +31,7 @@ class TestGetStrategy:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -44,9 +44,9 @@ class TestGetStrategy:
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
@ -61,12 +61,11 @@ class TestChangeStrategy:
|
||||
|
||||
assert result is True
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
@ -86,17 +85,17 @@ class TestChangeStrategy:
|
||||
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
assert existing.include_plugins == ["p2"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.select"),
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
@ -115,9 +114,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -127,16 +126,15 @@ class TestExcludePlugin:
|
||||
|
||||
assert result is True
|
||||
assert existing.exclude_plugins == ["p-existing", "p-new"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_removes_from_include_list_in_partial_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -151,9 +149,9 @@ class TestExcludePlugin:
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
@ -170,9 +168,9 @@ class TestExcludePlugin:
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
|
||||
@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
session.scalar.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -31,7 +31,7 @@ class TestGetPermission:
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -44,9 +44,9 @@ class TestGetPermission:
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -55,12 +55,11 @@ class TestChangePermission:
|
||||
)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -71,5 +70,4 @@ class TestChangePermission:
|
||||
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_not_called()
|
||||
|
||||
@ -275,48 +275,46 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
|
||||
msg_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_1.commit.return_value = None
|
||||
|
||||
msg_session_2 = MagicMock()
|
||||
msg_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Message else MagicMock()
|
||||
)
|
||||
msg_session_2.commit.return_value = None
|
||||
|
||||
conv_session_1 = MagicMock()
|
||||
conv_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_1.commit.return_value = None
|
||||
|
||||
conv_session_2 = MagicMock()
|
||||
conv_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock()
|
||||
)
|
||||
conv_session_2.commit.return_value = None
|
||||
|
||||
wal_session_1 = MagicMock()
|
||||
wal_session_1.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_1.commit.return_value = None
|
||||
|
||||
wal_session_2 = MagicMock()
|
||||
wal_session_2.query.side_effect = lambda model: (
|
||||
make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock()
|
||||
)
|
||||
wal_session_2.commit.return_value = None
|
||||
|
||||
session_wrappers = [
|
||||
_session_wrapper_for_no_autoflush(msg_session_1),
|
||||
_session_wrapper_for_no_autoflush(msg_session_2),
|
||||
_session_wrapper_for_no_autoflush(conv_session_1),
|
||||
_session_wrapper_for_no_autoflush(conv_session_2),
|
||||
_session_wrapper_for_no_autoflush(wal_session_1),
|
||||
_session_wrapper_for_no_autoflush(wal_session_2),
|
||||
_sessionmaker_wrapper_for_begin(msg_session_1),
|
||||
_sessionmaker_wrapper_for_begin(msg_session_2),
|
||||
_sessionmaker_wrapper_for_begin(conv_session_1),
|
||||
_sessionmaker_wrapper_for_begin(conv_session_2),
|
||||
_sessionmaker_wrapper_for_begin(wal_session_1),
|
||||
_sessionmaker_wrapper_for_begin(wal_session_2),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
|
||||
def fake_sessionmaker(*args, **kwargs):
|
||||
if kwargs.get("autoflush") is False:
|
||||
return session_wrappers.pop(0)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker)
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
stmt = MagicMock()
|
||||
@ -333,8 +331,6 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -
|
||||
run_repo = MagicMock()
|
||||
run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []]
|
||||
run_repo.delete_runs_by_ids.return_value = 1
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_node_execution_repository",
|
||||
@ -574,13 +570,18 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
|
||||
q_empty.limit.return_value = q_empty
|
||||
q_empty.all.return_value = []
|
||||
empty_session.query.return_value = q_empty
|
||||
empty_session.commit.return_value = None
|
||||
session_wrappers = [
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
_session_wrapper_for_no_autoflush(empty_session),
|
||||
_sessionmaker_wrapper_for_begin(empty_session),
|
||||
_sessionmaker_wrapper_for_begin(empty_session),
|
||||
_sessionmaker_wrapper_for_begin(empty_session),
|
||||
]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0))
|
||||
|
||||
def fake_sessionmaker(*args, **kwargs):
|
||||
if kwargs.get("autoflush") is False:
|
||||
return session_wrappers.pop(0)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker)
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
stmt = MagicMock()
|
||||
@ -606,8 +607,6 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte
|
||||
[],
|
||||
]
|
||||
run_repo.delete_runs_by_ids.return_value = 2
|
||||
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object())
|
||||
monkeypatch.setattr(
|
||||
service_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_node_execution_repository",
|
||||
|
||||
@ -40,7 +40,10 @@ class TestDatasourceProviderService:
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
"""
|
||||
with patch("services.datasource_provider_service.Session") as mock_cls:
|
||||
with (
|
||||
patch("services.datasource_provider_service.Session") as mock_cls,
|
||||
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
|
||||
):
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
@ -63,6 +66,8 @@ class TestDatasourceProviderService:
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
yield sess
|
||||
|
||||
@ -266,7 +271,6 @@ class TestDatasourceProviderService:
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
@ -333,7 +337,6 @@ class TestDatasourceProviderService:
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -352,7 +355,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# set_default_datasource_provider (lines 277-303)
|
||||
@ -370,7 +372,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().first.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_encrypter (lines 404-420)
|
||||
@ -460,7 +461,6 @@ class TestDatasourceProviderService:
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
@ -512,7 +512,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -523,7 +522,6 @@ class TestDatasourceProviderService:
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -571,7 +569,6 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
@ -747,7 +744,6 @@ class TestDatasourceProviderService:
|
||||
# encrypter must have been called with the new secret value
|
||||
self._enc.encrypt_token.assert_called()
|
||||
# commit must be called exactly once
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_datasource_credentials (lines 980-997)
|
||||
@ -758,7 +754,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
|
||||
@ -1,628 +0,0 @@
|
||||
"""
|
||||
Comprehensive unit tests for RecommendedAppService.
|
||||
|
||||
This test suite provides complete coverage of recommended app operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
|
||||
Tests fetching recommended apps with categories:
|
||||
- Successful retrieval with recommended apps
|
||||
- Fallback to builtin when no recommended apps
|
||||
- Different language support
|
||||
- Factory mode selection (remote, builtin, db)
|
||||
- Empty result handling
|
||||
|
||||
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
|
||||
Tests fetching individual app details:
|
||||
- Successful app detail retrieval
|
||||
- Different factory modes
|
||||
- App not found scenarios
|
||||
- Language-specific details
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
|
||||
are mocked for fast, isolated unit tests
|
||||
- **Factory Pattern**: Tests verify correct factory selection based on mode
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values and factory method calls
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Factory Modes:**
|
||||
- remote: Fetch from remote API
|
||||
- builtin: Use built-in templates
|
||||
- db: Fetch from database
|
||||
|
||||
**Fallback Logic:**
|
||||
- If remote/db returns no apps, fallback to builtin en-US templates
|
||||
- Ensures users always see some recommended apps
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
class RecommendedAppServiceTestDataFactory:
|
||||
"""
|
||||
Factory for creating test data and mock objects.
|
||||
|
||||
Provides reusable methods to create consistent mock objects for testing
|
||||
recommended app operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_recommended_apps_response(
|
||||
recommended_apps: list[dict] | None = None,
|
||||
categories: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for recommended apps.
|
||||
|
||||
Args:
|
||||
recommended_apps: List of recommended app dictionaries
|
||||
categories: List of category names
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended_apps and categories
|
||||
"""
|
||||
if recommended_apps is None:
|
||||
recommended_apps = [
|
||||
{
|
||||
"id": "app-1",
|
||||
"name": "Test App 1",
|
||||
"description": "Test description 1",
|
||||
"category": "productivity",
|
||||
},
|
||||
{
|
||||
"id": "app-2",
|
||||
"name": "Test App 2",
|
||||
"description": "Test description 2",
|
||||
"category": "communication",
|
||||
},
|
||||
]
|
||||
if categories is None:
|
||||
categories = ["productivity", "communication", "utilities"]
|
||||
|
||||
return {
|
||||
"recommended_apps": recommended_apps,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_app_detail_response(
|
||||
app_id: str = "app-123",
|
||||
name: str = "Test App",
|
||||
description: str = "Test description",
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock response for app detail.
|
||||
|
||||
Args:
|
||||
app_id: App identifier
|
||||
name: App name
|
||||
description: App description
|
||||
**kwargs: Additional fields
|
||||
|
||||
Returns:
|
||||
Dictionary with app details
|
||||
"""
|
||||
detail = {
|
||||
"id": app_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": kwargs.get("category", "productivity"),
|
||||
"icon": kwargs.get("icon", "🚀"),
|
||||
"model_config": kwargs.get("model_config", {}),
|
||||
}
|
||||
detail.update(kwargs)
|
||||
return detail
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory():
|
||||
"""Provide the test data factory to all tests."""
|
||||
return RecommendedAppServiceTestDataFactory
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
"""Test get_recommended_apps_and_categories operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of recommended apps when apps are returned."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
|
||||
expected_response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock factory and retrieval instance
|
||||
mock_retrieval_instance = MagicMock()
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_retrieval_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
assert len(result["recommended_apps"]) == 2
|
||||
assert len(result["categories"]) == 3
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback to builtin when no recommended apps are returned."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
|
||||
# Remote returns empty recommended_apps
|
||||
empty_response = {"recommended_apps": [], "categories": []}
|
||||
|
||||
# Builtin fallback response
|
||||
builtin_response = factory.create_recommended_apps_response(
|
||||
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
|
||||
)
|
||||
|
||||
# Mock remote retrieval instance (returns empty)
|
||||
mock_remote_instance = MagicMock()
|
||||
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
|
||||
|
||||
mock_remote_factory = MagicMock()
|
||||
mock_remote_factory.return_value = mock_remote_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory
|
||||
|
||||
# Mock builtin retrieval instance
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
|
||||
|
||||
# Assert
|
||||
assert result == builtin_response
|
||||
assert len(result["recommended_apps"]) == 1
|
||||
assert result["recommended_apps"][0]["id"] == "builtin-1"
|
||||
# Verify fallback was called with en-US (hardcoded)
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback when recommended_apps key is None."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
|
||||
|
||||
# Response with None recommended_apps
|
||||
none_response = {"recommended_apps": None, "categories": ["test"]}
|
||||
|
||||
# Builtin fallback response
|
||||
builtin_response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock db retrieval instance (returns None)
|
||||
mock_db_instance = MagicMock()
|
||||
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
|
||||
|
||||
mock_db_factory = MagicMock()
|
||||
mock_db_factory.return_value = mock_db_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory
|
||||
|
||||
# Mock builtin retrieval instance
|
||||
mock_builtin_instance = MagicMock()
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
|
||||
"""Test retrieval with different language codes."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
|
||||
languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]
|
||||
|
||||
for language in languages:
|
||||
# Create language-specific response
|
||||
lang_response = factory.create_recommended_apps_response(
|
||||
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(language)
|
||||
|
||||
# Assert
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that correct factory is selected based on mode."""
|
||||
# Arrange
|
||||
modes = ["remote", "builtin", "db"]
|
||||
|
||||
for mode in modes:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
|
||||
response = factory.create_recommended_apps_response()
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = response
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
"""Test get_recommend_app_detail operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of app detail."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "app-123"
|
||||
|
||||
expected_detail = factory.create_app_detail_response(
|
||||
app_id=app_id,
|
||||
name="Productivity App",
|
||||
description="A great productivity app",
|
||||
category="productivity",
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected_detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == expected_detail
|
||||
assert result["id"] == app_id
|
||||
assert result["name"] == "Productivity App"
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail retrieval with different factory modes."""
|
||||
# Arrange
|
||||
modes = ["remote", "builtin", "db"]
|
||||
app_id = "test-app"
|
||||
|
||||
for mode in modes:
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
|
||||
|
||||
detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["name"] == f"App from {mode}"
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that None is returned when app is not found."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "nonexistent-app"
|
||||
|
||||
# Mock retrieval instance returning None
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = None
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
|
||||
"""Test handling of empty dict response."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
|
||||
app_id = "app-empty"
|
||||
|
||||
# Mock retrieval instance returning empty dict
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = {}
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail with complex model configuration."""
|
||||
# Arrange
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
|
||||
app_id = "complex-app"
|
||||
|
||||
complex_model_config = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
"parameters": {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
expected_detail = factory.create_app_detail_response(
|
||||
app_id=app_id,
|
||||
name="Complex App",
|
||||
model_config=complex_model_config,
|
||||
workflows=["workflow-1", "workflow-2"],
|
||||
tools=["tool-1", "tool-2", "tool-3"],
|
||||
)
|
||||
|
||||
# Mock retrieval instance
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get_recommend_app_detail.return_value = expected_detail
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value = mock_instance
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
# Act
|
||||
result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail(app_id))
|
||||
|
||||
# Assert
|
||||
assert result["model_config"] == complex_model_config
|
||||
assert len(result["workflows"]) == 2
|
||||
assert len(result["tools"]) == 3
|
||||
|
||||
|
||||
# === Merged from test_recommended_app_service_additional.py ===
|
||||
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services import recommended_app_service as service_module
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], result)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_db_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(session=session))
|
||||
|
||||
# Assert
|
||||
return session
|
||||
|
||||
|
||||
def _mock_factory_for_apps(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
mode: str,
|
||||
result: dict[str, Any],
|
||||
fallback_result: dict[str, Any] | None = None,
|
||||
) -> tuple[MagicMock, MagicMock]:
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommended_apps_and_categories.return_value = result
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
|
||||
builtin_instance = MagicMock()
|
||||
if fallback_result is not None:
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_buildin_recommend_app_retrieval",
|
||||
MagicMock(return_value=builtin_instance),
|
||||
)
|
||||
return retrieval_instance, builtin_instance
|
||||
|
||||
|
||||
def test_get_recommended_apps_and_categories_should_not_query_trial_table_when_trial_feature_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]}
|
||||
retrieval_instance, builtin_instance = _mock_factory_for_apps(
|
||||
monkeypatch,
|
||||
mode="remote",
|
||||
result=expected,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called()
|
||||
mocked_db_session.scalar.assert_not_called()
|
||||
|
||||
|
||||
def test_get_recommended_apps_and_categories_should_fallback_and_enrich_can_trial_when_trial_feature_enabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
remote_result = {"recommended_apps": [], "categories": []}
|
||||
fallback_result = {"recommended_apps": [{"app_id": "app-1"}, {"app_id": "app-2"}], "categories": ["all"]}
|
||||
_, builtin_instance = _mock_factory_for_apps(
|
||||
monkeypatch,
|
||||
mode="remote",
|
||||
result=remote_result,
|
||||
fallback_result=fallback_result,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
mocked_db_session.scalar.side_effect = [SimpleNamespace(id="trial-app"), None]
|
||||
|
||||
# Act
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
|
||||
|
||||
# Assert
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
assert result["recommended_apps"][0]["can_trial"] is True
|
||||
assert result["recommended_apps"][1]["can_trial"] is False
|
||||
assert mocked_db_session.scalar.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("trial_query_result", "expected_can_trial"),
|
||||
[
|
||||
(SimpleNamespace(id="trial"), True),
|
||||
(None, False),
|
||||
],
|
||||
)
|
||||
def test_get_recommend_app_detail_should_set_can_trial_when_trial_feature_enabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocked_db_session: MagicMock,
|
||||
trial_query_result: Any,
|
||||
expected_can_trial: bool,
|
||||
) -> None:
|
||||
# Arrange
|
||||
detail = {"id": "app-1", "name": "Test App"}
|
||||
retrieval_instance = MagicMock()
|
||||
retrieval_instance.get_recommend_app_detail.return_value = detail
|
||||
retrieval_factory = MagicMock(return_value=retrieval_instance)
|
||||
monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False)
|
||||
monkeypatch.setattr(
|
||||
service_module.RecommendAppRetrievalFactory,
|
||||
"get_recommend_app_factory",
|
||||
MagicMock(return_value=retrieval_factory),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_system_features",
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
mocked_db_session.scalar.return_value = trial_query_result
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail("app-1"))
|
||||
|
||||
# Assert
|
||||
assert result["id"] == "app-1"
|
||||
assert result["can_trial"] is expected_can_trial
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trial_app_record_should_increment_count_when_existing_record_found(
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
existing_record = SimpleNamespace(count=3)
|
||||
mocked_db_session.scalar.return_value = existing_record
|
||||
|
||||
# Act
|
||||
RecommendedAppService.add_trial_app_record("app-1", "account-1")
|
||||
|
||||
# Assert
|
||||
assert existing_record.count == 4
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
mocked_db_session.commit.assert_called_once()
|
||||
mocked_db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_add_trial_app_record_should_create_new_record_when_no_existing_record(
|
||||
mocked_db_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocked_db_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
RecommendedAppService.add_trial_app_record("app-2", "account-2")
|
||||
|
||||
# Assert
|
||||
mocked_db_session.scalar.assert_called_once()
|
||||
mocked_db_session.add.assert_called_once()
|
||||
added = mocked_db_session.add.call_args.args[0]
|
||||
assert added.app_id == "app-2"
|
||||
assert added.account_id == "account-2"
|
||||
assert added.count == 1
|
||||
mocked_db_session.commit.assert_called_once()
|
||||
@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
mock_session_cm.__enter__.return_value = mock_session_instance
|
||||
mock_session_cm.__exit__.return_value = False
|
||||
mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm)
|
||||
mock_begin_cm = MagicMock()
|
||||
mock_begin_cm.__enter__.return_value = mock_session_instance
|
||||
mock_begin_cm.__exit__.return_value = False
|
||||
mock_sessionmaker_instance = MagicMock()
|
||||
mock_sessionmaker_instance.begin.return_value = mock_begin_cm
|
||||
mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance)
|
||||
return mock_session_instance
|
||||
|
||||
|
||||
@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
|
||||
# Assert
|
||||
assert result["result"] == "success"
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type(
|
||||
@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
|
||||
assert subscription.credentials == {"api_key": "new-key"}
|
||||
assert subscription.credential_expires_at == 100
|
||||
assert subscription.expires_at == 200
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
mock_delete_cache.assert_called_once()
|
||||
|
||||
|
||||
@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
|
||||
assert result == {"result": "success", "expires_at": 12345}
|
||||
assert subscription.credentials == {"access_token": "new"}
|
||||
assert subscription.credential_expires_at == 12345
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
cache.delete.assert_called_once()
|
||||
|
||||
|
||||
@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
|
||||
assert result == {"result": "success", "expires_at": 999}
|
||||
assert subscription.properties == {"p": "new-enc"}
|
||||
assert subscription.expires_at == 999
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
prop_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
|
||||
assert fake_model.encrypted_oauth_params == "{}"
|
||||
assert fake_model.enabled is True
|
||||
mock_session.add.assert_called_once_with(fake_model)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache(
|
||||
@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
|
||||
assert result == {"result": "success"}
|
||||
assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"}
|
||||
cache.delete.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
|
||||
@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exists", [True, False])
|
||||
|
||||
@ -617,6 +617,20 @@ class _SessionContext:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
@ -625,6 +639,7 @@ def flask_app() -> Flask:
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
@ -1240,7 +1255,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta
|
||||
# Assert
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
assert fake_session.commit_count == 2
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user