chore: not use request.scoped session (#37421)

Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
wangxiaolei 2026-06-23 03:38:24 +08:00 committed by GitHub
parent 7d2f25df8e
commit 0cc27dd401
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 633 additions and 253 deletions

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -22,7 +22,7 @@ from core.app.entities.queue_entities import (
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.db.session_factory import session_factory
from core.db.session_factory import create_session, session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
with create_session() as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
@ -204,6 +204,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
)
# Release the Flask scoped session before workflow execution so a checked-out DB connection
# is not held for the lifetime of the graph run.
db.session.close()
# RUN WORKFLOW
@ -368,7 +370,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use
"""
with sessionmaker(bind=db.engine).begin() as session:
with create_session() as session, session.begin():
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:

View File

@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.db.session_factory import create_session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
@ -47,7 +47,10 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
with create_session() as session:
app_record = session.scalar(app_stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -185,14 +188,18 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
with create_session() as session:
conversation_result = session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = session.scalar(msg_stmt)
if message_result is not None:
session.expunge(message_result)
session.expunge(conversation_result)
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import create_session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
@ -46,7 +47,10 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
with create_session() as session:
app_record = session.scalar(stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -216,6 +220,8 @@ class ChatAppRunner(AppRunner):
model=application_generate_entity.model_conf.model,
)
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
# is not held for the lifetime of the provider response.
db.session.close()
invoke_result = model_instance.invoke_llm(

View File

@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import create_session
from core.model_manager import ModelInstance
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -39,7 +40,10 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
with create_session() as session:
app_record = session.scalar(stmt)
if app_record:
session.expunge(app_record)
if not app_record:
raise ValueError("App not found")
@ -174,6 +178,8 @@ class CompletionAppRunner(AppRunner):
model=application_generate_entity.model_conf.model,
)
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
# is not held for the lifetime of the provider response.
db.session.close()
invoke_result = model_instance.invoke_llm(

View File

@ -3,6 +3,7 @@ import time
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@ -14,12 +15,12 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.db.session_factory import create_session
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from graphon.enums import WorkflowType
from graphon.graph import Graph
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@ -83,22 +84,24 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_from = self._resolve_user_from(invoke_from)
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
with create_session() as session:
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = session.get(EndUser, self.application_generate_entity.user_id)
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.get(Pipeline, app_config.app_id)
if not pipeline:
raise ValueError("Pipeline not found")
pipeline = session.get(Pipeline, app_config.app_id)
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
session.expunge(pipeline)
session.expunge(workflow)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
@ -208,12 +211,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.scalar(
workflow = session.scalar(
select(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.limit(1)
@ -298,11 +301,11 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()
with create_session() as session, session.begin():
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
session.add(document)

View File

@ -3,7 +3,6 @@ from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -13,10 +12,19 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db.session_factory import create_session
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models import Account
from models.model import App, AppMode, EndUser
from models import Account, TenantAccountJoin
from models.model import (
App,
AppMode,
AppModelConfig,
AppModelConfigDict,
EndUser,
load_annotation_reply_config,
)
from models.workflow import Workflow
from services.end_user_service import EndUserService
@ -30,18 +38,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
workflow = cls._get_workflow(app)
if workflow is None:
raise ValueError("unexpected app type")
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
app_model_config_dict = cls._get_app_model_config_dict(app)
if app_model_config_dict is None:
raise ValueError("unexpected app type")
features_dict = cast(dict[str, Any], app_model_config.to_dict())
features_dict = cast(dict[str, Any], app_model_config_dict)
user_input_form = features_dict.get("user_input_form", [])
@ -68,7 +76,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not user_id:
user = EndUserService.get_or_create_end_user(app)
else:
user = cls._get_user(user_id)
user = cls._get_user(user_id, app)
conversation_id = conversation_id or ""
@ -79,7 +87,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
case AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
workflow = cls._get_workflow(app)
if not workflow:
raise ValueError("unexpected app type")
return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files)
case AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
case _:
@ -101,7 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
match app.mode:
case AppMode.ADVANCED_CHAT:
workflow = app.workflow
workflow = cls._get_workflow(app)
if not workflow:
raise ValueError("unexpected app type")
@ -158,6 +169,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
def invoke_workflow_app(
cls,
app: App,
workflow: Workflow,
user: EndUser | Account,
stream: bool,
inputs: Mapping,
@ -166,10 +178,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke workflow app
"""
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
@ -207,16 +215,26 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
)
@classmethod
def _get_user(cls, user_id: str) -> EndUser | Account:
def _get_user(cls, user_id: str, app: App) -> EndUser | Account:
"""
get the user by user id
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(EndUser).where(EndUser.id == user_id)
with create_session() as session:
stmt = select(EndUser).where(
EndUser.id == user_id,
EndUser.tenant_id == app.tenant_id,
EndUser.app_id == app.id,
)
user = session.scalar(stmt)
if not user:
stmt = select(Account).where(Account.id == user_id)
stmt = select(Account).where(
Account.id == user_id,
Account.id == TenantAccountJoin.account_id,
TenantAccountJoin.tenant_id == app.tenant_id,
)
user = session.scalar(stmt)
if user:
session.expunge(user)
if not user:
raise ValueError("user not found")
@ -229,7 +247,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
with create_session() as session:
app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
if app:
session.expunge(app)
except Exception:
raise ValueError("app not found")
@ -237,3 +258,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
raise ValueError("app not found")
return app
@classmethod
def _get_workflow(cls, app: App) -> Workflow | None:
"""
get workflow without relying on App.workflow's request-scoped session property
"""
if not app.workflow_id:
return None
with create_session() as session:
workflow = session.scalar(
select(Workflow)
.where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id)
.limit(1)
)
if workflow:
session.expunge(workflow)
return workflow
@classmethod
def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None:
"""
get app model config features without relying on request-scoped session-backed model properties
"""
if not app.app_model_config_id:
return None
with create_session() as session:
app_model_config = session.scalar(
select(AppModelConfig)
.where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id)
.limit(1)
)
if app_model_config is None:
return None
annotation_reply = load_annotation_reply_config(session, app_model_config.app_id)
return app_model_config.to_dict(annotation_reply=annotation_reply)

View File

@ -774,26 +774,7 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
raise ValueError("Collection binding detail not found")
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {"enabled": False}
return load_annotation_reply_config(db.session(), self.app_id)
@property
def more_like_this_dict(self) -> EnabledConfig:
@ -864,7 +845,7 @@ class AppModelConfig(TypeBase):
},
)
def to_dict(self) -> AppModelConfigDict:
def to_dict(self, *, annotation_reply: AnnotationReplyConfig | None = None) -> AppModelConfigDict:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@ -872,7 +853,7 @@ class AppModelConfig(TypeBase):
"speech_to_text": self.speech_to_text_dict,
"text_to_speech": self.text_to_speech_dict,
"retriever_resource": self.retriever_resource_dict,
"annotation_reply": self.annotation_reply_dict,
"annotation_reply": annotation_reply if annotation_reply is not None else self.annotation_reply_dict,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"external_data_tools": self.external_data_tools_list,
@ -2038,6 +2019,30 @@ class AppAnnotationSetting(TypeBase):
)
def load_annotation_reply_config(session: Session, app_id: str) -> AnnotationReplyConfig:
annotation_setting = session.scalar(select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id))
if annotation_setting is None:
return {"enabled": False}
from .dataset import DatasetCollectionBinding
collection_binding_detail = session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == annotation_setting.collection_binding_id)
)
if collection_binding_detail is None:
raise ValueError("Collection binding detail not found")
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
class OperationLog(TypeBase):
__tablename__ = "operation_logs"
__table_args__ = (

View File

@ -25,6 +25,15 @@ MINIMAL_GRAPH = {
}
def _patch_create_session(mock_session: MagicMock):
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
mock_session.begin.return_value.__enter__.return_value = mock_session
mock_session.begin.return_value.__exit__.return_value = False
return patch("core.app.apps.advanced_chat.app_runner.create_session", return_value=session_context)
class TestAdvancedChatAppRunnerConversationVariables:
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
@ -135,10 +144,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -151,12 +158,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
mock_graph_runtime_state_class.return_value = MagicMock()
@ -281,10 +282,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -298,12 +297,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
mock_conv_vars = []
for var in workflow_vars:
@ -434,10 +427,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
_patch_create_session(mock_session),
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
patch.object(runner, "_init_graph") as mock_init_graph,
patch.object(
runner,
@ -450,12 +441,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
mock_graph_runtime_state_class.return_value = MagicMock()

View File

@ -3,6 +3,7 @@ from uuid import uuid4
import pytest
import core.app.apps.advanced_chat.app_runner as module
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import QueueStopEvent
@ -85,27 +86,24 @@ def build_runner():
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
"""Context manager that patches common heavy deps used by run()."""
# create_session() returns a context manager whose body yields a session that
# supports both scalar() (app record lookup) and begin()/scalars().all()
# (conversation variable initialization).
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock()
mock_session.scalars.return_value.all.return_value = []
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
mock_session.begin.return_value.__enter__.return_value = mock_session
mock_session.begin.return_value.__exit__.return_value = False
return patch.multiple(
"core.app.apps.advanced_chat.app_runner",
Session=MagicMock(
return_value=MagicMock(
__enter__=lambda s: s,
__exit__=lambda *a, **k: False,
scalar=lambda *a, **k: MagicMock(),
),
),
sessionmaker=MagicMock(
return_value=MagicMock(
begin=MagicMock(
return_value=MagicMock(
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
__exit__=lambda *a, **k: False,
),
),
),
),
create_session=MagicMock(return_value=session_context),
select=MagicMock(),
db=MagicMock(engine=MagicMock()),
session_factory=MagicMock(get_session_maker=MagicMock(return_value=MagicMock())),
RedisChannel=MagicMock(),
redis_client=MagicMock(),
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
@ -192,3 +190,42 @@ def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_
# Ensure no further steps executed
mock_anno.assert_not_called()
mock_init_graph.assert_not_called()
def test_run_closes_scoped_session_before_workflow_run(build_runner):
runner = build_runner
events = []
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock()
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False
workflow_entry = MagicMock()
def run_workflow():
events.append("run")
return iter([])
workflow_entry.run.side_effect = run_workflow
with (
patch.object(module, "create_session", return_value=session_context),
patch.object(module, "session_factory", MagicMock(get_session_maker=MagicMock(return_value=MagicMock()))),
patch.object(module, "RedisChannel"),
patch.object(module, "redis_client"),
patch.object(module, "WorkflowEntry", return_value=workflow_entry),
patch.object(module.db.session, "close", side_effect=lambda: events.append("close")),
patch.object(
runner,
"handle_input_moderation",
return_value=(False, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
),
patch.object(runner, "handle_annotation_reply", return_value=False),
patch.object(runner, "_initialize_conversation_variables", return_value=[]),
patch.object(runner, "_init_graph", return_value=MagicMock()),
):
runner.run()
assert events == ["close", "run"]

View File

@ -13,12 +13,24 @@ def runner():
return AgentChatAppRunner()
def patch_create_session(mocker: MockerFixture, *, return_value=None, side_effect=None):
session = mocker.MagicMock()
if side_effect is not None:
session.scalar.side_effect = side_effect
else:
session.scalar.return_value = return_value
session_context = mocker.MagicMock()
session_context.__enter__.return_value = session
mocker.patch("core.app.apps.agent_chat.app_runner.create_session", return_value=session_context)
return session
class TestAgentChatAppRunnerRun:
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
patch_create_session(mocker, return_value=None)
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
@ -37,7 +49,7 @@ class TestAgentChatAppRunnerRun:
conversation_id=None,
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
mocker.patch.object(runner, "direct_output")
@ -62,7 +74,7 @@ class TestAgentChatAppRunnerRun:
invoke_from=mocker.MagicMock(),
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
annotation = mocker.MagicMock(id="anno", content="answer")
@ -91,7 +103,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -121,7 +133,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -163,7 +175,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -179,10 +191,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
runner_cls = mocker.MagicMock()
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
@ -219,7 +228,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -235,10 +244,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
@ -267,7 +273,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -283,10 +289,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
runner_cls = mocker.MagicMock()
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
@ -323,10 +326,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, None],
)
patch_create_session(mocker, side_effect=[app_record, None])
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -357,10 +357,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
)
patch_create_session(mocker, side_effect=[app_record, mocker.MagicMock(id="conv"), None])
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -391,7 +388,7 @@ class TestAgentChatAppRunnerRun:
user_id="user",
)
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
patch_create_session(mocker, return_value=app_record)
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
@ -407,10 +404,7 @@ class TestAgentChatAppRunnerRun:
conversation = mocker.MagicMock(id="conv")
message = mocker.MagicMock(id="msg")
mocker.patch(
"core.app.apps.agent_chat.app_runner.db.session.scalar",
side_effect=[app_record, conversation, message],
)
patch_create_session(mocker, side_effect=[app_record, conversation, message])
with pytest.raises(ValueError):
runner.run(generate_entity, mocker.MagicMock(), conversation, message)

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch
import pytest
@ -29,6 +30,19 @@ class DummyQueueManager:
self.published.append((event, pub_from))
@contextmanager
def patched_create_session(*, return_value=None, side_effect=None):
session = MagicMock()
if side_effect is not None:
session.scalar.side_effect = side_effect
else:
session.scalar.return_value = return_value
session_context = MagicMock()
session_context.__enter__.return_value = session
with patch("core.app.apps.chat.app_runner.create_session", return_value=session_context):
yield session
class TestChatAppGenerator:
def test_generate_requires_query(self):
generator = ChatAppGenerator()
@ -167,7 +181,7 @@ class TestChatAppRunner:
invoke_from=InvokeFrom.SERVICE_API,
)
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
with patched_create_session(return_value=None):
with pytest.raises(ValueError):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
@ -195,10 +209,7 @@ class TestChatAppRunner:
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
patch.object(ChatAppRunner, "direct_output") as mock_direct,
@ -233,10 +244,7 @@ class TestChatAppRunner:
annotation = SimpleNamespace(id="ann-1", content="answer")
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
@ -272,13 +280,73 @@ class TestChatAppRunner:
)
with (
patch(
"core.app.apps.chat.app_runner.db.session.scalar",
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
),
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
):
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
def test_run_closes_scoped_session_before_stream_consumption(self):
runner = ChatAppRunner()
app_config = SimpleNamespace(
app_id="app-1",
tenant_id="tenant-1",
prompt_template=None,
external_data_variables=[],
dataset=None,
additional_features=None,
)
app_generate_entity = DummyGenerateEntity(
app_config=app_config,
model_conf=SimpleNamespace(provider_model_bundle=None, model="model-1", parameters={}),
inputs={},
query="hi",
files=[],
file_upload_config=None,
conversation_id=None,
stream=True,
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
)
events = []
queue_manager = DummyQueueManager()
model_instance = MagicMock()
def invoke_stream():
events.append("first-chunk")
yield "chunk"
def invoke_llm(**kwargs):
events.append("invoke")
return invoke_stream()
with (
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=False),
patch.object(ChatAppRunner, "recalc_llm_max_tokens"),
patch.object(
ChatAppRunner,
"_handle_invoke_result",
side_effect=lambda invoke_result, **kwargs: list(invoke_result),
) as mock_handle,
patch("core.app.apps.chat.app_runner.ModelInstance", return_value=model_instance),
patch("core.app.apps.chat.app_runner.db.session.close", side_effect=lambda: events.append("close")),
):
model_instance.invoke_llm.side_effect = invoke_llm
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
assert events == ["close", "invoke", "first-chunk"]
mock_handle.assert_called_once_with(
invoke_result=ANY,
queue_manager=queue_manager,
stream=True,
message_id="m1",
user_id="user-1",
tenant_id="tenant-1",
)

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import ANY, MagicMock, patch
import pytest
from pytest_mock import MockerFixture
@ -47,25 +48,28 @@ def _build_generate_entity(app_config, file_upload_config=None):
)
@contextmanager
def patched_create_session(*, return_value=None):
session = MagicMock()
session.scalar.return_value = return_value
session_context = MagicMock()
session_context.__enter__.return_value = session
with patch.object(module, "create_session", return_value=session_context):
yield session
class TestCompletionAppRunner:
def test_run_app_not_found(self, runner, mocker: MockerFixture):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
with patched_create_session(return_value=None):
with pytest.raises(ValueError):
runner.run(app_generate_entity, MagicMock(), MagicMock())
def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
@ -74,7 +78,8 @@ class TestCompletionAppRunner:
runner.direct_output = MagicMock()
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner.direct_output.assert_called_once()
runner._handle_invoke_result.assert_not_called()
@ -82,10 +87,6 @@ class TestCompletionAppRunner:
def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
@ -94,18 +95,14 @@ class TestCompletionAppRunner:
runner.check_hosting_moderation = MagicMock(return_value=True)
runner._handle_invoke_result = MagicMock()
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
runner._handle_invoke_result.assert_not_called()
def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
session.close = MagicMock()
mocker.patch.object(module.db, "session", session)
retrieve_config = MagicMock(query_variable="qvar")
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
additional_features = MagicMock(show_retrieve_source=True)
@ -135,19 +132,56 @@ class TestCompletionAppRunner:
model_instance.invoke_llm.return_value = "invoke_result"
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
dataset_retrieval.retrieve.assert_called_once()
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
runner._handle_invoke_result.assert_called_once()
def test_run_closes_scoped_session_before_stream_consumption(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config)
queue_manager = MagicMock()
events = []
runner.organize_prompt_messages = MagicMock(return_value=([], None))
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=False)
runner.recalc_llm_max_tokens = MagicMock()
runner._handle_invoke_result = MagicMock(side_effect=lambda invoke_result, **kwargs: list(invoke_result))
model_instance = MagicMock()
def invoke_stream():
events.append("first-chunk")
yield "chunk"
def invoke_llm(**kwargs):
events.append("invoke")
return invoke_stream()
model_instance.invoke_llm.side_effect = invoke_llm
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
mocker.patch.object(module.db.session, "close", side_effect=lambda: events.append("close"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, queue_manager, MagicMock(id="msg"))
assert events == ["close", "invoke", "first-chunk"]
runner._handle_invoke_result.assert_called_once_with(
invoke_result=ANY,
queue_manager=queue_manager,
stream=True,
message_id="msg",
user_id="user",
tenant_id="tenant",
)
def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture):
app_record = MagicMock(id="app1", tenant_id="tenant")
session = mocker.MagicMock()
session.scalar.return_value = app_record
mocker.patch.object(module.db, "session", session)
app_config = _build_app_config()
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
@ -155,7 +189,8 @@ class TestCompletionAppRunner:
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
runner.check_hosting_moderation = MagicMock(return_value=True)
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
with patched_create_session(return_value=app_record):
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
assert (
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]

View File

@ -53,6 +53,32 @@ def _build_app_generate_entity() -> SimpleNamespace:
)
def _patch_create_session(mocker: MockerFixture, session: MagicMock, *, events: list[str] | None = None):
"""Patch create_session() to yield ``session`` inside its ``with`` body and ``begin()`` block.
The runner now obtains short-lived sessions via ``create_session()`` instead of the
Flask scoped ``db.session``, so tests patch the module-level ``create_session`` and
hand back a context manager that yields the mock session.
"""
session_context = MagicMock()
def enter_session():
if events is not None:
events.append("session_enter")
return session
def exit_session(*args):
if events is not None:
events.append("session_exit")
return False
session_context.__enter__.side_effect = enter_session
session_context.__exit__.side_effect = exit_session
session.begin.return_value.__enter__.return_value = session
session.begin.return_value.__exit__.return_value = False
return mocker.patch.object(module, "create_session", return_value=session_context)
@pytest.fixture
def runner():
app_generate_entity = _build_app_generate_entity()
@ -77,13 +103,14 @@ def test_get_app_id(runner):
assert runner._get_app_id() == "pipe"
def test_get_workflow_returns_workflow(mocker, runner):
def test_get_workflow_returns_workflow(runner):
pipeline = MagicMock(tenant_id="tenant", id="pipe")
workflow = MagicMock(id="wf")
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
session = MagicMock()
session.scalar.return_value = workflow
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
result = runner.get_workflow(session=session, pipeline=pipeline, workflow_id="wf")
assert result == workflow
@ -116,7 +143,7 @@ def test_update_document_status_on_failure(mocker, runner):
session = MagicMock()
session.scalar.return_value = document
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
event = GraphRunFailedEvent(error="boom")
@ -124,7 +151,10 @@ def test_update_document_status_on_failure(mocker, runner):
assert document.indexing_status == "error"
assert document.error == "boom"
session.commit.assert_called_once()
session.add.assert_called_once_with(document)
session.begin.assert_called_once()
session.begin.return_value.__enter__.assert_called_once()
session.begin.return_value.__exit__.assert_called_once()
def test_run_pipeline_not_found(mocker: MockerFixture):
@ -135,7 +165,7 @@ def test_run_pipeline_not_found(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -158,7 +188,7 @@ def test_run_workflow_not_initialized(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -184,7 +214,7 @@ def test_run_single_iteration_path(mocker: MockerFixture):
session = MagicMock()
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session)
runner = PipelineRunner(
application_generate_entity=app_generate_entity,
@ -229,10 +259,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
pipeline = MagicMock(id="pipe")
end_user = MagicMock(session_id="sess")
events = []
session = MagicMock()
session.get.side_effect = [end_user, pipeline]
mocker.patch.object(module.db, "session", session)
_patch_create_session(mocker, session, events=events)
workflow = MagicMock(
id="wf",
@ -276,10 +307,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
workflow_entry = MagicMock()
workflow_entry.graph_engine = MagicMock()
workflow_entry.run.return_value = []
workflow_entry.run.side_effect = lambda: events.append("workflow_run") or []
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
runner.run()
assert events == ["session_enter", "session_exit", "workflow_run"]
runner._init_rag_pipeline_graph.assert_called_once()

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from pytest_mock import MockerFixture
from sqlalchemy.dialects import postgresql
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
@ -16,6 +17,16 @@ class _Chunk(BaseModel):
value: int
def _build_app_model_config(result: dict | None = None):
app_model_config = MagicMock()
app_model_config.app_id = "app-1"
app_model_config.to_dict.return_value = result or {
"user_input_form": [{"name": "bar"}],
"annotation_reply": {"enabled": False},
}
return app_model_config
class TestBaseBackwardsInvocation:
def test_convert_to_event_stream_with_generator_and_error(self):
def _stream():
@ -42,12 +53,25 @@ class TestBaseBackwardsInvocation:
class TestPluginAppBackwardsInvocation:
def patch_create_session(self, mocker: MockerFixture, *, return_value=None, side_effect=None):
session = MagicMock()
if side_effect is not None:
session.scalar.side_effect = side_effect
else:
session.scalar.return_value = return_value
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
mocker.patch("core.plugin.backwards_invocation.app.create_session", return_value=session_ctx)
return session
def test_fetch_app_info_workflow_path(self, mocker: MockerFixture):
workflow = MagicMock()
workflow.features_dict = {"feature": "v"}
workflow.user_input_form.return_value = [{"name": "foo"}]
app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow)
app = MagicMock(mode=AppMode.WORKFLOW)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
mapper = mocker.patch(
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
return_value={"mapped": True},
@ -59,10 +83,10 @@ class TestPluginAppBackwardsInvocation:
mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}])
def test_fetch_app_info_model_config_path(self, mocker: MockerFixture):
model_config = MagicMock()
model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"}
app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config)
model_config_dict = {"user_input_form": [{"name": "bar"}], "k": "v"}
app = MagicMock(mode=AppMode.COMPLETION)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app_model_config_dict", return_value=model_config_dict)
mocker.patch(
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
return_value={"mapped": True},
@ -85,8 +109,10 @@ class TestPluginAppBackwardsInvocation:
def test_invoke_app_routes_by_mode(self, mocker: MockerFixture, mode, route_method):
app = MagicMock(mode=mode)
user = MagicMock()
workflow = MagicMock()
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True})
result = PluginAppBackwardsInvocation.invoke_app(
@ -106,7 +132,9 @@ class TestPluginAppBackwardsInvocation:
def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker: MockerFixture):
app = MagicMock(mode=AppMode.WORKFLOW)
end_user = MagicMock()
workflow = MagicMock()
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
get_or_create = mocker.patch(
"core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user",
return_value=end_user,
@ -126,7 +154,8 @@ class TestPluginAppBackwardsInvocation:
assert result == {"ok": True}
get_or_create.assert_called_once_with(app)
assert route.call_args.args[1] is end_user
assert route.call_args.args[1] is workflow
assert route.call_args.args[2] is end_user
def test_invoke_app_missing_query_for_chat_raises(self, mocker: MockerFixture):
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT))
@ -190,7 +219,7 @@ class TestPluginAppBackwardsInvocation:
app = MagicMock()
app.mode = AppMode.ADVANCED_CHAT
app.workflow = workflow
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
mocker.patch(
"core.plugin.backwards_invocation.app.db",
@ -217,8 +246,9 @@ class TestPluginAppBackwardsInvocation:
assert isinstance(pause_state_config, PauseStateLayerConfig)
assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self):
app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None)
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self, mocker: MockerFixture):
app = MagicMock(mode=AppMode.ADVANCED_CHAT)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None)
with pytest.raises(ValueError, match="unexpected app type"):
PluginAppBackwardsInvocation.invoke_chat_app(
app=app,
@ -249,7 +279,6 @@ class TestPluginAppBackwardsInvocation:
app = MagicMock()
app.mode = AppMode.WORKFLOW
app.workflow = workflow
mocker.patch(
"core.plugin.backwards_invocation.app.db",
@ -262,6 +291,7 @@ class TestPluginAppBackwardsInvocation:
result = PluginAppBackwardsInvocation.invoke_workflow_app(
app=app,
workflow=workflow,
user=MagicMock(),
stream=False,
inputs={"k": "v"},
@ -274,12 +304,18 @@ class TestPluginAppBackwardsInvocation:
assert isinstance(pause_state_config, PauseStateLayerConfig)
assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_workflow_app_without_workflow_raises(self):
app = MagicMock(mode=AppMode.WORKFLOW, workflow=None)
def test_invoke_app_workflow_without_workflow_raises(self, mocker: MockerFixture):
app = MagicMock(mode=AppMode.WORKFLOW)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None)
with pytest.raises(ValueError, match="unexpected app type"):
PluginAppBackwardsInvocation.invoke_workflow_app(
app=app,
user=MagicMock(),
PluginAppBackwardsInvocation.invoke_app(
app_id="app",
user_id="user",
tenant_id="tenant",
conversation_id=None,
query=None,
stream=False,
inputs={},
files=[],
@ -297,58 +333,97 @@ class TestPluginAppBackwardsInvocation:
assert spy.call_count == 1
def test_get_user_returns_end_user(self, mocker: MockerFixture):
session = MagicMock()
session.scalar.side_effect = [MagicMock(id="end-user")]
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
session = self.patch_create_session(mocker, side_effect=[MagicMock(id="end-user")])
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
user = PluginAppBackwardsInvocation._get_user("uid", app)
user = PluginAppBackwardsInvocation._get_user("uid")
assert user.id == "end-user"
stmt = session.scalar.call_args_list[0].args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "end_users.id" in compiled
assert "end_users.tenant_id" in compiled
assert "end_users.app_id" in compiled
assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1", "app_id_1": "app-1"}
def test_get_user_falls_back_to_account_user(self, mocker: MockerFixture):
session = MagicMock()
session.scalar.side_effect = [None, MagicMock(id="account-user")]
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
session = self.patch_create_session(mocker, side_effect=[None, MagicMock(id="account-user")])
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
user = PluginAppBackwardsInvocation._get_user("uid", app)
user = PluginAppBackwardsInvocation._get_user("uid")
assert user.id == "account-user"
stmt = session.scalar.call_args_list[1].args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "accounts.id" in compiled
assert "tenant_account_joins.account_id" in compiled
assert "tenant_account_joins.tenant_id" in compiled
assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1"}
def test_get_user_raises_when_user_not_found(self, mocker: MockerFixture):
session = MagicMock()
session.scalar.side_effect = [None, None]
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
self.patch_create_session(mocker, side_effect=[None, None])
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with pytest.raises(ValueError, match="user not found"):
PluginAppBackwardsInvocation._get_user("uid")
PluginAppBackwardsInvocation._get_user("uid", app)
def test_get_app_returns_app(self, mocker: MockerFixture):
app_obj = MagicMock(id="app")
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
self.patch_create_session(mocker, return_value=app_obj)
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
def test_get_app_raises_when_missing(self, mocker: MockerFixture):
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
self.patch_create_session(mocker, return_value=None)
with pytest.raises(ValueError, match="app not found"):
PluginAppBackwardsInvocation._get_app("app", "tenant")
def test_get_app_raises_when_query_fails(self, mocker: MockerFixture):
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
self.patch_create_session(mocker, side_effect=RuntimeError("db down"))
with pytest.raises(ValueError, match="app not found"):
PluginAppBackwardsInvocation._get_app("app", "tenant")
def test_get_workflow_stays_inside_app_boundary(self, mocker: MockerFixture):
workflow = MagicMock(id="workflow")
session = self.patch_create_session(mocker, return_value=workflow)
app = SimpleNamespace(id="app-1", tenant_id="tenant-1", workflow_id="workflow-1")
assert PluginAppBackwardsInvocation._get_workflow(app) is workflow
stmt = session.scalar.call_args.args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "workflows.id" in compiled
assert "workflows.tenant_id" in compiled
assert "workflows.app_id" in compiled
assert stmt.compile().params == {
"id_1": "workflow-1",
"tenant_id_1": "tenant-1",
"app_id_1": "app-1",
"param_1": 1,
}
def test_get_app_model_config_dict_uses_explicit_session_for_annotation_reply(self, mocker: MockerFixture):
annotation_reply = {"enabled": False}
app_model_config = _build_app_model_config()
session = self.patch_create_session(mocker, return_value=app_model_config)
load_annotation_reply_config = mocker.patch(
"core.plugin.backwards_invocation.app.load_annotation_reply_config",
return_value=annotation_reply,
)
app = SimpleNamespace(id="app-1", app_model_config_id="config-1")
result = PluginAppBackwardsInvocation._get_app_model_config_dict(app)
assert result is not None
assert result["user_input_form"] == [{"name": "bar"}]
assert result["annotation_reply"] == annotation_reply
load_annotation_reply_config.assert_called_once_with(session, "app-1")
app_model_config.to_dict.assert_called_once_with(annotation_reply=annotation_reply)
stmt = session.scalar.call_args.args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "app_model_configs.id" in compiled
assert "app_model_configs.app_id" in compiled
assert stmt.compile().params == {"id_1": "config-1", "app_id_1": "app-1", "param_1": 1}

View File

@ -12,10 +12,11 @@ import json
from datetime import UTC, datetime
from decimal import Decimal
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.dialects import postgresql
from models.enums import ConversationFromSource
from models.model import (
@ -29,6 +30,7 @@ from models.model import (
Message,
MessageAnnotation,
Site,
load_annotation_reply_config,
)
@ -342,6 +344,70 @@ class TestAppModelConfig:
# Assert
assert result == questions
def test_to_dict_uses_injected_annotation_reply(self):
config = AppModelConfig(app_id=str(uuid4()))
annotation_reply = {"enabled": False}
with patch.object(
AppModelConfig,
"annotation_reply_dict",
new_callable=PropertyMock,
side_effect=AssertionError("annotation_reply_dict should not be accessed"),
):
result = config.to_dict(annotation_reply=annotation_reply)
assert result["annotation_reply"] == annotation_reply
class TestAnnotationReplyConfigLoader:
def test_load_annotation_reply_config_returns_disabled_when_setting_missing(self):
session = MagicMock()
session.scalar.return_value = None
result = load_annotation_reply_config(session, "app-1")
assert result == {"enabled": False}
session.scalar.assert_called_once()
stmt = session.scalar.call_args.args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "app_annotation_settings.app_id" in compiled
assert stmt.compile().params == {"app_id_1": "app-1"}
def test_load_annotation_reply_config_returns_embedding_model(self):
session = MagicMock()
annotation_setting = SimpleNamespace(
id="annotation-1",
score_threshold=0.7,
collection_binding_id="binding-1",
)
collection_binding = SimpleNamespace(provider_name="provider", model_name="embedding")
session.scalar.side_effect = [annotation_setting, collection_binding]
result = load_annotation_reply_config(session, "app-1")
assert result == {
"id": "annotation-1",
"enabled": True,
"score_threshold": 0.7,
"embedding_model": {
"embedding_provider_name": "provider",
"embedding_model_name": "embedding",
},
}
assert session.scalar.call_count == 2
stmt = session.scalar.call_args_list[1].args[0]
compiled = str(stmt.compile(dialect=postgresql.dialect()))
assert "dataset_collection_bindings.id" in compiled
assert stmt.compile().params == {"id_1": "binding-1"}
def test_load_annotation_reply_config_raises_when_binding_missing(self):
session = MagicMock()
annotation_setting = SimpleNamespace(collection_binding_id="binding-1")
session.scalar.side_effect = [annotation_setting, None]
with pytest.raises(ValueError, match="Collection binding detail not found"):
load_annotation_reply_config(session, "app-1")
class TestConversationModel:
"""Test suite for Conversation model integrity."""