From 0cc27dd4010d04e80812c7701d2efdee172e1780 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 23 Jun 2026 03:38:24 +0800 Subject: [PATCH] chore: not use request.scoped session (#37421) Co-authored-by: WH-2099 --- api/core/app/apps/advanced_chat/app_runner.py | 10 +- api/core/app/apps/agent_chat/app_runner.py | 21 ++- api/core/app/apps/chat/app_runner.py | 8 +- api/core/app/apps/completion/app_runner.py | 8 +- api/core/app/apps/pipeline/pipeline_runner.py | 51 +++--- api/core/plugin/backwards_invocation/app.py | 97 +++++++++-- api/models/model.py | 49 +++--- .../test_app_runner_conversation_variables.py | 39 ++--- .../test_app_runner_input_moderation.py | 73 ++++++-- .../agent_chat/test_agent_chat_app_runner.py | 60 +++---- .../chat/test_app_generator_and_runner.py | 96 +++++++++-- .../app/apps/completion/test_app_runner.py | 91 +++++++--- .../app/apps/pipeline/test_pipeline_runner.py | 52 ++++-- .../plugin/test_backwards_invocation_app.py | 163 +++++++++++++----- .../unit_tests/models/test_app_models.py | 68 +++++++- 15 files changed, 633 insertions(+), 253 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 256521ab654..67397965384 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -22,7 +22,7 @@ from core.app.entities.queue_entities import ( from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.db.session_factory import session_factory +from core.db.session_factory import create_session, session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow_execution_id=self.application_generate_entity.workflow_run_id, ) - with Session(db.engine, expire_on_commit=False) as session: + with create_session() as session: app_record = session.scalar(select(App).where(App.id == app_config.app_id)) if not app_record: @@ -204,6 +204,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), ) + # Release the Flask scoped session before workflow execution so a checked-out DB connection + # is not held for the lifetime of the graph run. db.session.close() # RUN WORKFLOW @@ -368,7 +370,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): :return: List of conversation variables ready for use """ - with sessionmaker(bind=db.engine).begin() as session: + with create_session() as session, session.begin(): existing_variables = self._load_existing_conversation_variables(session) if not existing_variables: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index cae0eee0df0..5f9c75129b5 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent +from core.db.session_factory import create_session from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError -from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel @@ -47,7 +47,10 @@ class AgentChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(AgentChatAppConfig, app_config) app_stmt = select(App).where(App.id == app_config.app_id) - app_record = db.session.scalar(app_stmt) + with create_session() as session: + app_record = session.scalar(app_stmt) + if app_record: + session.expunge(app_record) if not app_record: raise ValueError("App not found") @@ -185,14 +188,18 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING conversation_stmt = select(Conversation).where(Conversation.id == conversation.id) - conversation_result = db.session.scalar(conversation_stmt) - if conversation_result is None: - raise ValueError("Conversation not found") msg_stmt = select(Message).where(Message.id == message.id) - message_result = db.session.scalar(msg_stmt) + with create_session() as session: + conversation_result = session.scalar(conversation_stmt) + if conversation_result is None: + raise ValueError("Conversation not found") + + message_result = session.scalar(msg_stmt) + if message_result is not None: + session.expunge(message_result) + session.expunge(conversation_result) if message_result is None: raise ValueError("Message not found") - db.session.close() runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 077c5239f39..9c2eaf60dc7 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.db.session_factory import create_session from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError @@ -46,7 +47,10 @@ class ChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(ChatAppConfig, app_config) stmt = select(App).where(App.id == app_config.app_id) - app_record = db.session.scalar(stmt) + with create_session() as session: + app_record = session.scalar(stmt) + if app_record: + session.expunge(app_record) if not app_record: raise ValueError("App not found") @@ -216,6 +220,8 @@ class ChatAppRunner(AppRunner): model=application_generate_entity.model_conf.model, ) + # Release the Flask scoped session before LLM streaming so a checked-out DB connection + # is not held for the lifetime of the provider response. db.session.close() invoke_result = model_instance.invoke_llm( diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 6bb1ecdcb19..38ef672ae22 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.db.session_factory import create_session from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -39,7 +40,10 @@ class CompletionAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(CompletionAppConfig, app_config) stmt = select(App).where(App.id == app_config.app_id) - app_record = db.session.scalar(stmt) + with create_session() as session: + app_record = session.scalar(stmt) + if app_record: + session.expunge(app_record) if not app_record: raise ValueError("App not found") @@ -174,6 +178,8 @@ class CompletionAppRunner(AppRunner): model=application_generate_entity.model_conf.model, ) + # Release the Flask scoped session before LLM streaming so a checked-out DB connection + # is not held for the lifetime of the provider response. db.session.close() invoke_result = model_instance.invoke_llm( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 2ee0ae27ebc..3ad0990cbb4 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -3,6 +3,7 @@ import time from typing import cast from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig @@ -14,12 +15,12 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.db.session_factory import create_session from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id from core.workflow.system_variables import build_bootstrap_variables, build_system_variables from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from graphon.enums import WorkflowType from graphon.graph import Graph from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent @@ -83,22 +84,24 @@ class PipelineRunner(WorkflowBasedAppRunner): user_from = self._resolve_user_from(invoke_from) user_id = None - if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.get(EndUser, self.application_generate_entity.user_id) - if end_user: - user_id = end_user.session_id - else: - user_id = self.application_generate_entity.user_id + with create_session() as session: + if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = session.get(EndUser, self.application_generate_entity.user_id) + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id - pipeline = db.session.get(Pipeline, app_config.app_id) - if not pipeline: - raise ValueError("Pipeline not found") + pipeline = session.get(Pipeline, app_config.app_id) + if not pipeline: + raise ValueError("Pipeline not found") - workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") + workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") - db.session.close() + session.expunge(pipeline) + session.expunge(workflow) # if only single iteration run is requested if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: @@ -208,12 +211,12 @@ class PipelineRunner(WorkflowBasedAppRunner): ) self._handle_event(workflow_entry, event) - def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: + def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None: """ Get workflow """ # fetch workflow by workflow_id - workflow = db.session.scalar( + workflow = session.scalar( select(Workflow) .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id) .limit(1) @@ -298,11 +301,11 @@ class PipelineRunner(WorkflowBasedAppRunner): """ if isinstance(event, GraphRunFailedEvent): if document_id and dataset_id: - document = db.session.scalar( - select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) - ) - if document: - document.indexing_status = "error" - document.error = event.error or "Unknown error" - db.session.add(document) - db.session.commit() + with create_session() as session, session.begin(): + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) + if document: + document.indexing_status = "error" + document.error = event.error or "Unknown error" + session.add(document) diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index c76cb865c31..d022b002f72 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -3,7 +3,6 @@ from collections.abc import Generator, Mapping from typing import Any, cast from sqlalchemy import select -from sqlalchemy.orm import Session from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -13,10 +12,19 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig +from core.db.session_factory import create_session from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from extensions.ext_database import db -from models import Account -from models.model import App, AppMode, EndUser +from models import Account, TenantAccountJoin +from models.model import ( + App, + AppMode, + AppModelConfig, + AppModelConfigDict, + EndUser, + load_annotation_reply_config, +) +from models.workflow import Workflow from services.end_user_service import EndUserService @@ -30,18 +38,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """Retrieve app parameters.""" if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow = app.workflow + workflow = cls._get_workflow(app) if workflow is None: raise ValueError("unexpected app type") features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: - app_model_config = app.app_model_config - if app_model_config is None: + app_model_config_dict = cls._get_app_model_config_dict(app) + if app_model_config_dict is None: raise ValueError("unexpected app type") - features_dict = cast(dict[str, Any], app_model_config.to_dict()) + features_dict = cast(dict[str, Any], app_model_config_dict) user_input_form = features_dict.get("user_input_form", []) @@ -68,7 +76,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if not user_id: user = EndUserService.get_or_create_end_user(app) else: - user = cls._get_user(user_id) + user = cls._get_user(user_id, app) conversation_id = conversation_id or "" @@ -79,7 +87,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) case AppMode.WORKFLOW: - return cls.invoke_workflow_app(app, user, stream, inputs, files) + workflow = cls._get_workflow(app) + if not workflow: + raise ValueError("unexpected app type") + return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files) case AppMode.COMPLETION: return cls.invoke_completion_app(app, user, stream, inputs, files) case _: @@ -101,7 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ match app.mode: case AppMode.ADVANCED_CHAT: - workflow = app.workflow + workflow = cls._get_workflow(app) if not workflow: raise ValueError("unexpected app type") @@ -158,6 +169,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): def invoke_workflow_app( cls, app: App, + workflow: Workflow, user: EndUser | Account, stream: bool, inputs: Mapping, @@ -166,10 +178,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke workflow app """ - workflow = app.workflow - if not workflow: - raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( session_factory=db.engine, state_owner_user_id=workflow.created_by, @@ -207,16 +215,26 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> EndUser | Account: + def _get_user(cls, user_id: str, app: App) -> EndUser | Account: """ get the user by user id """ - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(EndUser).where(EndUser.id == user_id) + with create_session() as session: + stmt = select(EndUser).where( + EndUser.id == user_id, + EndUser.tenant_id == app.tenant_id, + EndUser.app_id == app.id, + ) user = session.scalar(stmt) if not user: - stmt = select(Account).where(Account.id == user_id) + stmt = select(Account).where( + Account.id == user_id, + Account.id == TenantAccountJoin.account_id, + TenantAccountJoin.tenant_id == app.tenant_id, + ) user = session.scalar(stmt) + if user: + session.expunge(user) if not user: raise ValueError("user not found") @@ -229,7 +247,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) + with create_session() as session: + app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) + if app: + session.expunge(app) except Exception: raise ValueError("app not found") @@ -237,3 +258,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): raise ValueError("app not found") return app + + @classmethod + def _get_workflow(cls, app: App) -> Workflow | None: + """ + get workflow without relying on App.workflow's request-scoped session property + """ + if not app.workflow_id: + return None + + with create_session() as session: + workflow = session.scalar( + select(Workflow) + .where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id) + .limit(1) + ) + if workflow: + session.expunge(workflow) + return workflow + + @classmethod + def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None: + """ + get app model config features without relying on request-scoped session-backed model properties + """ + if not app.app_model_config_id: + return None + + with create_session() as session: + app_model_config = session.scalar( + select(AppModelConfig) + .where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id) + .limit(1) + ) + if app_model_config is None: + return None + + annotation_reply = load_annotation_reply_config(session, app_model_config.app_id) + return app_model_config.to_dict(annotation_reply=annotation_reply) diff --git a/api/models/model.py b/api/models/model.py index 4c73385f3aa..947cbf6fe4a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -774,26 +774,7 @@ class AppModelConfig(TypeBase): @property def annotation_reply_dict(self) -> AnnotationReplyConfig: - annotation_setting = db.session.scalar( - select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id) - ) - if annotation_setting: - collection_binding_detail = annotation_setting.collection_binding_detail - if not collection_binding_detail: - raise ValueError("Collection binding detail not found") - - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } - - else: - return {"enabled": False} + return load_annotation_reply_config(db.session(), self.app_id) @property def more_like_this_dict(self) -> EnabledConfig: @@ -864,7 +845,7 @@ class AppModelConfig(TypeBase): }, ) - def to_dict(self) -> AppModelConfigDict: + def to_dict(self, *, annotation_reply: AnnotationReplyConfig | None = None) -> AppModelConfigDict: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -872,7 +853,7 @@ class AppModelConfig(TypeBase): "speech_to_text": self.speech_to_text_dict, "text_to_speech": self.text_to_speech_dict, "retriever_resource": self.retriever_resource_dict, - "annotation_reply": self.annotation_reply_dict, + "annotation_reply": annotation_reply if annotation_reply is not None else self.annotation_reply_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, "external_data_tools": self.external_data_tools_list, @@ -2038,6 +2019,30 @@ class AppAnnotationSetting(TypeBase): ) +def load_annotation_reply_config(session: Session, app_id: str) -> AnnotationReplyConfig: + annotation_setting = session.scalar(select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id)) + if annotation_setting is None: + return {"enabled": False} + + from .dataset import DatasetCollectionBinding + + collection_binding_detail = session.scalar( + select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == annotation_setting.collection_binding_id) + ) + if collection_binding_detail is None: + raise ValueError("Collection binding detail not found") + + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + + class OperationLog(TypeBase): __tablename__ = "operation_logs" __table_args__ = ( diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index ac9eddb680f..1970e5c1522 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -25,6 +25,15 @@ MINIMAL_GRAPH = { } +def _patch_create_session(mock_session: MagicMock): + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + mock_session.begin.return_value.__enter__.return_value = mock_session + mock_session.begin.return_value.__exit__.return_value = False + return patch("core.app.apps.advanced_chat.app_runner.create_session", return_value=session_context) + + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -135,10 +144,8 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( - patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, - patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + _patch_create_session(mock_session), patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, - patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, patch.object( runner, @@ -151,12 +158,6 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client, patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): - # Setup mocks - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.engine = MagicMock() - # Mock GraphRuntimeState to accept the variable pool mock_graph_runtime_state_class.return_value = MagicMock() @@ -281,10 +282,8 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( - patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, - patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + _patch_create_session(mock_session), patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, - patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, patch.object( runner, @@ -298,12 +297,6 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client, patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): - # Setup mocks - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.engine = MagicMock() - # Mock ConversationVariable.from_variable to return mock objects mock_conv_vars = [] for var in workflow_vars: @@ -434,10 +427,8 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( - patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, - patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + _patch_create_session(mock_session), patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, - patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, patch.object( runner, @@ -450,12 +441,6 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client, patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): - # Setup mocks - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.engine = MagicMock() - # Mock GraphRuntimeState to accept the variable pool mock_graph_runtime_state_class.return_value = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 2076e42e9f9..2e3f7645c73 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -3,6 +3,7 @@ from uuid import uuid4 import pytest +import core.app.apps.advanced_chat.app_runner as module from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import QueueStopEvent @@ -85,27 +86,24 @@ def build_runner(): def _patch_common_run_deps(runner: AdvancedChatAppRunner): """Context manager that patches common heavy deps used by run().""" + # create_session() returns a context manager whose body yields a session that + # supports both scalar() (app record lookup) and begin()/scalars().all() + # (conversation variable initialization). + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + mock_session.begin.return_value.__enter__.return_value = mock_session + mock_session.begin.return_value.__exit__.return_value = False + return patch.multiple( "core.app.apps.advanced_chat.app_runner", - Session=MagicMock( - return_value=MagicMock( - __enter__=lambda s: s, - __exit__=lambda *a, **k: False, - scalar=lambda *a, **k: MagicMock(), - ), - ), - sessionmaker=MagicMock( - return_value=MagicMock( - begin=MagicMock( - return_value=MagicMock( - __enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))), - __exit__=lambda *a, **k: False, - ), - ), - ), - ), + create_session=MagicMock(return_value=session_context), select=MagicMock(), - db=MagicMock(engine=MagicMock()), + session_factory=MagicMock(get_session_maker=MagicMock(return_value=MagicMock())), RedisChannel=MagicMock(), redis_client=MagicMock(), WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}), @@ -192,3 +190,42 @@ def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_ # Ensure no further steps executed mock_anno.assert_not_called() mock_init_graph.assert_not_called() + + +def test_run_closes_scoped_session_before_workflow_run(build_runner): + runner = build_runner + events = [] + + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + + workflow_entry = MagicMock() + + def run_workflow(): + events.append("run") + return iter([]) + + workflow_entry.run.side_effect = run_workflow + + with ( + patch.object(module, "create_session", return_value=session_context), + patch.object(module, "session_factory", MagicMock(get_session_maker=MagicMock(return_value=MagicMock()))), + patch.object(module, "RedisChannel"), + patch.object(module, "redis_client"), + patch.object(module, "WorkflowEntry", return_value=workflow_entry), + patch.object(module.db.session, "close", side_effect=lambda: events.append("close")), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, runner.application_generate_entity.inputs, runner.application_generate_entity.query), + ), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch.object(runner, "_initialize_conversation_variables", return_value=[]), + patch.object(runner, "_init_graph", return_value=MagicMock()), + ): + runner.run() + + assert events == ["close", "run"] diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index d7988cbf74d..af2fb22ec78 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -13,12 +13,24 @@ def runner(): return AgentChatAppRunner() +def patch_create_session(mocker: MockerFixture, *, return_value=None, side_effect=None): + session = mocker.MagicMock() + if side_effect is not None: + session.scalar.side_effect = side_effect + else: + session.scalar.return_value = return_value + session_context = mocker.MagicMock() + session_context.__enter__.return_value = session + mocker.patch("core.app.apps.agent_chat.app_runner.create_session", return_value=session_context) + return session + + class TestAgentChatAppRunnerRun: def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture): app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock()) generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None) + patch_create_session(mocker, return_value=None) with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) @@ -37,7 +49,7 @@ class TestAgentChatAppRunnerRun: conversation_id=None, ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad")) mocker.patch.object(runner, "direct_output") @@ -62,7 +74,7 @@ class TestAgentChatAppRunnerRun: invoke_from=mocker.MagicMock(), ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) annotation = mocker.MagicMock(id="anno", content="answer") @@ -91,7 +103,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -121,7 +133,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -163,7 +175,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -179,10 +191,7 @@ class TestAgentChatAppRunnerRun: conversation = mocker.MagicMock(id="conv") message = mocker.MagicMock(id="msg") - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, conversation, message], - ) + patch_create_session(mocker, side_effect=[app_record, conversation, message]) runner_cls = mocker.MagicMock() mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls) @@ -219,7 +228,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -235,10 +244,7 @@ class TestAgentChatAppRunnerRun: conversation = mocker.MagicMock(id="conv") message = mocker.MagicMock(id="msg") - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, conversation, message], - ) + patch_create_session(mocker, side_effect=[app_record, conversation, message]) with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), conversation, message) @@ -267,7 +273,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -283,10 +289,7 @@ class TestAgentChatAppRunnerRun: conversation = mocker.MagicMock(id="conv") message = mocker.MagicMock(id="msg") - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, conversation, message], - ) + patch_create_session(mocker, side_effect=[app_record, conversation, message]) runner_cls = mocker.MagicMock() mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls) @@ -323,10 +326,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, None], - ) + patch_create_session(mocker, side_effect=[app_record, None]) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -357,10 +357,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, mocker.MagicMock(id="conv"), None], - ) + patch_create_session(mocker, side_effect=[app_record, mocker.MagicMock(id="conv"), None]) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -391,7 +388,7 @@ class TestAgentChatAppRunnerRun: user_id="user", ) - mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + patch_create_session(mocker, return_value=app_record) mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) @@ -407,10 +404,7 @@ class TestAgentChatAppRunnerRun: conversation = mocker.MagicMock(id="conv") message = mocker.MagicMock(id="msg") - mocker.patch( - "core.app.apps.agent_chat.app_runner.db.session.scalar", - side_effect=[app_record, conversation, message], - ) + patch_create_session(mocker, side_effect=[app_record, conversation, message]) with pytest.raises(ValueError): runner.run(generate_entity, mocker.MagicMock(), conversation, message) diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 6f104a5eaa1..23334dbe67c 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -1,5 +1,6 @@ +from contextlib import contextmanager from types import SimpleNamespace -from unittest.mock import Mock, patch +from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -29,6 +30,19 @@ class DummyQueueManager: self.published.append((event, pub_from)) +@contextmanager +def patched_create_session(*, return_value=None, side_effect=None): + session = MagicMock() + if side_effect is not None: + session.scalar.side_effect = side_effect + else: + session.scalar.return_value = return_value + session_context = MagicMock() + session_context.__enter__.return_value = session + with patch("core.app.apps.chat.app_runner.create_session", return_value=session_context): + yield session + + class TestChatAppGenerator: def test_generate_requires_query(self): generator = ChatAppGenerator() @@ -167,7 +181,7 @@ class TestChatAppRunner: invoke_from=InvokeFrom.SERVICE_API, ) - with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None): + with patched_create_session(return_value=None): with pytest.raises(ValueError): runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) @@ -195,10 +209,7 @@ class TestChatAppRunner: ) with ( - patch( - "core.app.apps.chat.app_runner.db.session.scalar", - return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), - ), + patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")), patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")), patch.object(ChatAppRunner, "direct_output") as mock_direct, @@ -233,10 +244,7 @@ class TestChatAppRunner: annotation = SimpleNamespace(id="ann-1", content="answer") with ( - patch( - "core.app.apps.chat.app_runner.db.session.scalar", - return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), - ), + patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")), patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation), @@ -272,13 +280,73 @@ class TestChatAppRunner: ) with ( - patch( - "core.app.apps.chat.app_runner.db.session.scalar", - return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), - ), + patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")), patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None), patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True), ): runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + def test_run_closes_scoped_session_before_stream_consumption(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model="model-1", parameters={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=True, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + events = [] + queue_manager = DummyQueueManager() + model_instance = MagicMock() + + def invoke_stream(): + events.append("first-chunk") + yield "chunk" + + def invoke_llm(**kwargs): + events.append("invoke") + return invoke_stream() + + with ( + patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None), + patch.object(ChatAppRunner, "check_hosting_moderation", return_value=False), + patch.object(ChatAppRunner, "recalc_llm_max_tokens"), + patch.object( + ChatAppRunner, + "_handle_invoke_result", + side_effect=lambda invoke_result, **kwargs: list(invoke_result), + ) as mock_handle, + patch("core.app.apps.chat.app_runner.ModelInstance", return_value=model_instance), + patch("core.app.apps.chat.app_runner.db.session.close", side_effect=lambda: events.append("close")), + ): + model_instance.invoke_llm.side_effect = invoke_llm + runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1")) + + assert events == ["close", "invoke", "first-chunk"] + mock_handle.assert_called_once_with( + invoke_result=ANY, + queue_manager=queue_manager, + stream=True, + message_id="m1", + user_id="user-1", + tenant_id="tenant-1", + ) diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 8dcf6e91935..2fdb197852a 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -1,5 +1,6 @@ +from contextlib import contextmanager from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock, patch import pytest from pytest_mock import MockerFixture @@ -47,25 +48,28 @@ def _build_generate_entity(app_config, file_upload_config=None): ) +@contextmanager +def patched_create_session(*, return_value=None): + session = MagicMock() + session.scalar.return_value = return_value + session_context = MagicMock() + session_context.__enter__.return_value = session + with patch.object(module, "create_session", return_value=session_context): + yield session + + class TestCompletionAppRunner: def test_run_app_not_found(self, runner, mocker: MockerFixture): - session = mocker.MagicMock() - session.scalar.return_value = None - mocker.patch.object(module.db, "session", session) - app_config = _build_app_config() app_generate_entity = _build_generate_entity(app_config) - with pytest.raises(ValueError): - runner.run(app_generate_entity, MagicMock(), MagicMock()) + with patched_create_session(return_value=None): + with pytest.raises(ValueError): + runner.run(app_generate_entity, MagicMock(), MagicMock()) def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture): app_record = MagicMock(id="app1", tenant_id="tenant") - session = mocker.MagicMock() - session.scalar.return_value = app_record - mocker.patch.object(module.db, "session", session) - app_config = _build_app_config() app_generate_entity = _build_generate_entity(app_config) @@ -74,7 +78,8 @@ class TestCompletionAppRunner: runner.direct_output = MagicMock() runner._handle_invoke_result = MagicMock() - runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + with patched_create_session(return_value=app_record): + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) runner.direct_output.assert_called_once() runner._handle_invoke_result.assert_not_called() @@ -82,10 +87,6 @@ class TestCompletionAppRunner: def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture): app_record = MagicMock(id="app1", tenant_id="tenant") - session = mocker.MagicMock() - session.scalar.return_value = app_record - mocker.patch.object(module.db, "session", session) - app_config = _build_app_config() app_generate_entity = _build_generate_entity(app_config) @@ -94,18 +95,14 @@ class TestCompletionAppRunner: runner.check_hosting_moderation = MagicMock(return_value=True) runner._handle_invoke_result = MagicMock() - runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + with patched_create_session(return_value=app_record): + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) runner._handle_invoke_result.assert_not_called() def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture): app_record = MagicMock(id="app1", tenant_id="tenant") - session = mocker.MagicMock() - session.scalar.return_value = app_record - session.close = MagicMock() - mocker.patch.object(module.db, "session", session) - retrieve_config = MagicMock(query_variable="qvar") dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config) additional_features = MagicMock(show_retrieve_source=True) @@ -135,19 +132,56 @@ class TestCompletionAppRunner: model_instance.invoke_llm.return_value = "invoke_result" mocker.patch.object(module, "ModelInstance", return_value=model_instance) - runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant")) + with patched_create_session(return_value=app_record): + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant")) dataset_retrieval.retrieve.assert_called_once() assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input" runner._handle_invoke_result.assert_called_once() + def test_run_closes_scoped_session_before_stream_consumption(self, runner, mocker: MockerFixture): + app_record = MagicMock(id="app1", tenant_id="tenant") + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + queue_manager = MagicMock() + + events = [] + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=False) + runner.recalc_llm_max_tokens = MagicMock() + runner._handle_invoke_result = MagicMock(side_effect=lambda invoke_result, **kwargs: list(invoke_result)) + + model_instance = MagicMock() + + def invoke_stream(): + events.append("first-chunk") + yield "chunk" + + def invoke_llm(**kwargs): + events.append("invoke") + return invoke_stream() + + model_instance.invoke_llm.side_effect = invoke_llm + mocker.patch.object(module, "ModelInstance", return_value=model_instance) + mocker.patch.object(module.db.session, "close", side_effect=lambda: events.append("close")) + + with patched_create_session(return_value=app_record): + runner.run(app_generate_entity, queue_manager, MagicMock(id="msg")) + + assert events == ["close", "invoke", "first-chunk"] + runner._handle_invoke_result.assert_called_once_with( + invoke_result=ANY, + queue_manager=queue_manager, + stream=True, + message_id="msg", + user_id="user", + tenant_id="tenant", + ) + def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture): app_record = MagicMock(id="app1", tenant_id="tenant") - session = mocker.MagicMock() - session.scalar.return_value = app_record - mocker.patch.object(module.db, "session", session) - app_config = _build_app_config() app_generate_entity = _build_generate_entity(app_config, file_upload_config=None) @@ -155,7 +189,8 @@ class TestCompletionAppRunner: runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) runner.check_hosting_moderation = MagicMock(return_value=True) - runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + with patched_create_session(return_value=app_record): + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) assert ( runner.organize_prompt_messages.call_args.kwargs["image_detail_config"] diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index 1eed76cf843..f24f799464f 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -53,6 +53,32 @@ def _build_app_generate_entity() -> SimpleNamespace: ) +def _patch_create_session(mocker: MockerFixture, session: MagicMock, *, events: list[str] | None = None): + """Patch create_session() to yield ``session`` inside its ``with`` body and ``begin()`` block. + + The runner now obtains short-lived sessions via ``create_session()`` instead of the + Flask scoped ``db.session``, so tests patch the module-level ``create_session`` and + hand back a context manager that yields the mock session. + """ + session_context = MagicMock() + + def enter_session(): + if events is not None: + events.append("session_enter") + return session + + def exit_session(*args): + if events is not None: + events.append("session_exit") + return False + + session_context.__enter__.side_effect = enter_session + session_context.__exit__.side_effect = exit_session + session.begin.return_value.__enter__.return_value = session + session.begin.return_value.__exit__.return_value = False + return mocker.patch.object(module, "create_session", return_value=session_context) + + @pytest.fixture def runner(): app_generate_entity = _build_app_generate_entity() @@ -77,13 +103,14 @@ def test_get_app_id(runner): assert runner._get_app_id() == "pipe" -def test_get_workflow_returns_workflow(mocker, runner): +def test_get_workflow_returns_workflow(runner): pipeline = MagicMock(tenant_id="tenant", id="pipe") workflow = MagicMock(id="wf") - mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow))) + session = MagicMock() + session.scalar.return_value = workflow - result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") + result = runner.get_workflow(session=session, pipeline=pipeline, workflow_id="wf") assert result == workflow @@ -116,7 +143,7 @@ def test_update_document_status_on_failure(mocker, runner): session = MagicMock() session.scalar.return_value = document - mocker.patch.object(module.db, "session", session) + _patch_create_session(mocker, session) event = GraphRunFailedEvent(error="boom") @@ -124,7 +151,10 @@ def test_update_document_status_on_failure(mocker, runner): assert document.indexing_status == "error" assert document.error == "boom" - session.commit.assert_called_once() + session.add.assert_called_once_with(document) + session.begin.assert_called_once() + session.begin.return_value.__enter__.assert_called_once() + session.begin.return_value.__exit__.assert_called_once() def test_run_pipeline_not_found(mocker: MockerFixture): @@ -135,7 +165,7 @@ def test_run_pipeline_not_found(mocker: MockerFixture): session = MagicMock() session.get.side_effect = [None, None] - mocker.patch.object(module.db, "session", session) + _patch_create_session(mocker, session) runner = PipelineRunner( application_generate_entity=app_generate_entity, @@ -158,7 +188,7 @@ def test_run_workflow_not_initialized(mocker: MockerFixture): session = MagicMock() session.get.side_effect = [None, pipeline] - mocker.patch.object(module.db, "session", session) + _patch_create_session(mocker, session) runner = PipelineRunner( application_generate_entity=app_generate_entity, @@ -184,7 +214,7 @@ def test_run_single_iteration_path(mocker: MockerFixture): session = MagicMock() session.get.side_effect = [end_user, pipeline] - mocker.patch.object(module.db, "session", session) + _patch_create_session(mocker, session) runner = PipelineRunner( application_generate_entity=app_generate_entity, @@ -229,10 +259,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture): pipeline = MagicMock(id="pipe") end_user = MagicMock(session_id="sess") + events = [] session = MagicMock() session.get.side_effect = [end_user, pipeline] - mocker.patch.object(module.db, "session", session) + _patch_create_session(mocker, session, events=events) workflow = MagicMock( id="wf", @@ -276,10 +307,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture): workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() - workflow_entry.run.return_value = [] + workflow_entry.run.side_effect = lambda: events.append("workflow_run") or [] mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) runner.run() + assert events == ["session_enter", "session_exit", "workflow_run"] runner._init_rag_pipeline_graph.assert_called_once() diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index 2ed7c70ed94..e95ca8a7bf8 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel from pytest_mock import MockerFixture +from sqlalchemy.dialects import postgresql from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation @@ -16,6 +17,16 @@ class _Chunk(BaseModel): value: int +def _build_app_model_config(result: dict | None = None): + app_model_config = MagicMock() + app_model_config.app_id = "app-1" + app_model_config.to_dict.return_value = result or { + "user_input_form": [{"name": "bar"}], + "annotation_reply": {"enabled": False}, + } + return app_model_config + + class TestBaseBackwardsInvocation: def test_convert_to_event_stream_with_generator_and_error(self): def _stream(): @@ -42,12 +53,25 @@ class TestBaseBackwardsInvocation: class TestPluginAppBackwardsInvocation: + def patch_create_session(self, mocker: MockerFixture, *, return_value=None, side_effect=None): + session = MagicMock() + if side_effect is not None: + session.scalar.side_effect = side_effect + else: + session.scalar.return_value = return_value + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.create_session", return_value=session_ctx) + return session + def test_fetch_app_info_workflow_path(self, mocker: MockerFixture): workflow = MagicMock() workflow.features_dict = {"feature": "v"} workflow.user_input_form.return_value = [{"name": "foo"}] - app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow) + app = MagicMock(mode=AppMode.WORKFLOW) mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow) mapper = mocker.patch( "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", return_value={"mapped": True}, @@ -59,10 +83,10 @@ class TestPluginAppBackwardsInvocation: mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}]) def test_fetch_app_info_model_config_path(self, mocker: MockerFixture): - model_config = MagicMock() - model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"} - app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config) + model_config_dict = {"user_input_form": [{"name": "bar"}], "k": "v"} + app = MagicMock(mode=AppMode.COMPLETION) mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app_model_config_dict", return_value=model_config_dict) mocker.patch( "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", return_value={"mapped": True}, @@ -85,8 +109,10 @@ class TestPluginAppBackwardsInvocation: def test_invoke_app_routes_by_mode(self, mocker: MockerFixture, mode, route_method): app = MagicMock(mode=mode) user = MagicMock() + workflow = MagicMock() mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow) route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True}) result = PluginAppBackwardsInvocation.invoke_app( @@ -106,7 +132,9 @@ class TestPluginAppBackwardsInvocation: def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker: MockerFixture): app = MagicMock(mode=AppMode.WORKFLOW) end_user = MagicMock() + workflow = MagicMock() mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow) get_or_create = mocker.patch( "core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user", return_value=end_user, @@ -126,7 +154,8 @@ class TestPluginAppBackwardsInvocation: assert result == {"ok": True} get_or_create.assert_called_once_with(app) - assert route.call_args.args[1] is end_user + assert route.call_args.args[1] is workflow + assert route.call_args.args[2] is end_user def test_invoke_app_missing_query_for_chat_raises(self, mocker: MockerFixture): mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT)) @@ -190,7 +219,7 @@ class TestPluginAppBackwardsInvocation: app = MagicMock() app.mode = AppMode.ADVANCED_CHAT - app.workflow = workflow + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow) mocker.patch( "core.plugin.backwards_invocation.app.db", @@ -217,8 +246,9 @@ class TestPluginAppBackwardsInvocation: assert isinstance(pause_state_config, PauseStateLayerConfig) assert pause_state_config.state_owner_user_id == "owner-id" - def test_invoke_chat_app_advanced_chat_without_workflow_raises(self): - app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None) + def test_invoke_chat_app_advanced_chat_without_workflow_raises(self, mocker: MockerFixture): + app = MagicMock(mode=AppMode.ADVANCED_CHAT) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None) with pytest.raises(ValueError, match="unexpected app type"): PluginAppBackwardsInvocation.invoke_chat_app( app=app, @@ -249,7 +279,6 @@ class TestPluginAppBackwardsInvocation: app = MagicMock() app.mode = AppMode.WORKFLOW - app.workflow = workflow mocker.patch( "core.plugin.backwards_invocation.app.db", @@ -262,6 +291,7 @@ class TestPluginAppBackwardsInvocation: result = PluginAppBackwardsInvocation.invoke_workflow_app( app=app, + workflow=workflow, user=MagicMock(), stream=False, inputs={"k": "v"}, @@ -274,12 +304,18 @@ class TestPluginAppBackwardsInvocation: assert isinstance(pause_state_config, PauseStateLayerConfig) assert pause_state_config.state_owner_user_id == "owner-id" - def test_invoke_workflow_app_without_workflow_raises(self): - app = MagicMock(mode=AppMode.WORKFLOW, workflow=None) + def test_invoke_app_workflow_without_workflow_raises(self, mocker: MockerFixture): + app = MagicMock(mode=AppMode.WORKFLOW) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None) with pytest.raises(ValueError, match="unexpected app type"): - PluginAppBackwardsInvocation.invoke_workflow_app( - app=app, - user=MagicMock(), + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query=None, stream=False, inputs={}, files=[], @@ -297,58 +333,97 @@ class TestPluginAppBackwardsInvocation: assert spy.call_count == 1 def test_get_user_returns_end_user(self, mocker: MockerFixture): - session = MagicMock() - session.scalar.side_effect = [MagicMock(id="end-user")] - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) - mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + session = self.patch_create_session(mocker, side_effect=[MagicMock(id="end-user")]) + app = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + user = PluginAppBackwardsInvocation._get_user("uid", app) - user = PluginAppBackwardsInvocation._get_user("uid") assert user.id == "end-user" + stmt = session.scalar.call_args_list[0].args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "end_users.id" in compiled + assert "end_users.tenant_id" in compiled + assert "end_users.app_id" in compiled + assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1", "app_id_1": "app-1"} def test_get_user_falls_back_to_account_user(self, mocker: MockerFixture): - session = MagicMock() - session.scalar.side_effect = [None, MagicMock(id="account-user")] - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) - mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + session = self.patch_create_session(mocker, side_effect=[None, MagicMock(id="account-user")]) + app = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + user = PluginAppBackwardsInvocation._get_user("uid", app) - user = PluginAppBackwardsInvocation._get_user("uid") assert user.id == "account-user" + stmt = session.scalar.call_args_list[1].args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "accounts.id" in compiled + assert "tenant_account_joins.account_id" in compiled + assert "tenant_account_joins.tenant_id" in compiled + assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1"} def test_get_user_raises_when_user_not_found(self, mocker: MockerFixture): - session = MagicMock() - session.scalar.side_effect = [None, None] - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) - mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + self.patch_create_session(mocker, side_effect=[None, None]) + app = SimpleNamespace(id="app-1", tenant_id="tenant-1") with pytest.raises(ValueError, match="user not found"): - PluginAppBackwardsInvocation._get_user("uid") + PluginAppBackwardsInvocation._get_user("uid", app) def test_get_app_returns_app(self, mocker: MockerFixture): app_obj = MagicMock(id="app") - db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) - mocker.patch("core.plugin.backwards_invocation.app.db", db) + self.patch_create_session(mocker, return_value=app_obj) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker: MockerFixture): - db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) - mocker.patch("core.plugin.backwards_invocation.app.db", db) + self.patch_create_session(mocker, return_value=None) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker: MockerFixture): - db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) - mocker.patch("core.plugin.backwards_invocation.app.db", db) + self.patch_create_session(mocker, side_effect=RuntimeError("db down")) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") + + def test_get_workflow_stays_inside_app_boundary(self, mocker: MockerFixture): + workflow = MagicMock(id="workflow") + session = self.patch_create_session(mocker, return_value=workflow) + app = SimpleNamespace(id="app-1", tenant_id="tenant-1", workflow_id="workflow-1") + + assert PluginAppBackwardsInvocation._get_workflow(app) is workflow + + stmt = session.scalar.call_args.args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "workflows.id" in compiled + assert "workflows.tenant_id" in compiled + assert "workflows.app_id" in compiled + assert stmt.compile().params == { + "id_1": "workflow-1", + "tenant_id_1": "tenant-1", + "app_id_1": "app-1", + "param_1": 1, + } + + def test_get_app_model_config_dict_uses_explicit_session_for_annotation_reply(self, mocker: MockerFixture): + annotation_reply = {"enabled": False} + app_model_config = _build_app_model_config() + session = self.patch_create_session(mocker, return_value=app_model_config) + load_annotation_reply_config = mocker.patch( + "core.plugin.backwards_invocation.app.load_annotation_reply_config", + return_value=annotation_reply, + ) + app = SimpleNamespace(id="app-1", app_model_config_id="config-1") + + result = PluginAppBackwardsInvocation._get_app_model_config_dict(app) + + assert result is not None + assert result["user_input_form"] == [{"name": "bar"}] + assert result["annotation_reply"] == annotation_reply + load_annotation_reply_config.assert_called_once_with(session, "app-1") + app_model_config.to_dict.assert_called_once_with(annotation_reply=annotation_reply) + + stmt = session.scalar.call_args.args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "app_model_configs.id" in compiled + assert "app_model_configs.app_id" in compiled + assert stmt.compile().params == {"id_1": "config-1", "app_id_1": "app-1", "param_1": 1} diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index d3d0c5dce00..684d7f9fa8e 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -12,10 +12,11 @@ import json from datetime import UTC, datetime from decimal import Decimal from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from uuid import uuid4 import pytest +from sqlalchemy.dialects import postgresql from models.enums import ConversationFromSource from models.model import ( @@ -29,6 +30,7 @@ from models.model import ( Message, MessageAnnotation, Site, + load_annotation_reply_config, ) @@ -342,6 +344,70 @@ class TestAppModelConfig: # Assert assert result == questions + def test_to_dict_uses_injected_annotation_reply(self): + config = AppModelConfig(app_id=str(uuid4())) + annotation_reply = {"enabled": False} + + with patch.object( + AppModelConfig, + "annotation_reply_dict", + new_callable=PropertyMock, + side_effect=AssertionError("annotation_reply_dict should not be accessed"), + ): + result = config.to_dict(annotation_reply=annotation_reply) + + assert result["annotation_reply"] == annotation_reply + + +class TestAnnotationReplyConfigLoader: + def test_load_annotation_reply_config_returns_disabled_when_setting_missing(self): + session = MagicMock() + session.scalar.return_value = None + + result = load_annotation_reply_config(session, "app-1") + + assert result == {"enabled": False} + session.scalar.assert_called_once() + stmt = session.scalar.call_args.args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "app_annotation_settings.app_id" in compiled + assert stmt.compile().params == {"app_id_1": "app-1"} + + def test_load_annotation_reply_config_returns_embedding_model(self): + session = MagicMock() + annotation_setting = SimpleNamespace( + id="annotation-1", + score_threshold=0.7, + collection_binding_id="binding-1", + ) + collection_binding = SimpleNamespace(provider_name="provider", model_name="embedding") + session.scalar.side_effect = [annotation_setting, collection_binding] + + result = load_annotation_reply_config(session, "app-1") + + assert result == { + "id": "annotation-1", + "enabled": True, + "score_threshold": 0.7, + "embedding_model": { + "embedding_provider_name": "provider", + "embedding_model_name": "embedding", + }, + } + assert session.scalar.call_count == 2 + stmt = session.scalar.call_args_list[1].args[0] + compiled = str(stmt.compile(dialect=postgresql.dialect())) + assert "dataset_collection_bindings.id" in compiled + assert stmt.compile().params == {"id_1": "binding-1"} + + def test_load_annotation_reply_config_raises_when_binding_missing(self): + session = MagicMock() + annotation_setting = SimpleNamespace(collection_binding_id="binding-1") + session.scalar.side_effect = [annotation_setting, None] + + with pytest.raises(ValueError, match="Collection binding detail not found"): + load_annotation_reply_config(session, "app-1") + class TestConversationModel: """Test suite for Conversation model integrity."""