diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a884a1c7f9..7b4cb98bd4 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variable_loader import VariableLoader from graphon.variables.variables import Variable from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): :return: List of conversation variables ready for use """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: existing_variables = self._load_existing_conversation_variables(session) if not existing_variables: @@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Convert to Variable objects for use in the workflow conversation_variables = [var.to_variable() for var in existing_variables] - session.commit() return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 5203de225c..0ce9ddce9e 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes import BuiltinNodeTypes from graphon.runtime import GraphRuntimeState from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 49af169e88..f1b8b08eaa 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -7,7 +7,7 @@ from typing import Union from graphon.entities import WorkflowStartReason from graphon.enums import WorkflowExecutionStatus from graphon.runtime import GraphRuntimeState -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager @@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 182f1b767d..a454217768 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,6 +1,6 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.entities.model_entities import ModelStatus @@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL pool_type="paid", ) else: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: stmt = ( update(Provider) .where( @@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL ) ) session.execute(stmt) - session.commit() diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 9df78a7830..6bb177fe02 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): event = message.event if isinstance(event, QueueErrorEvent): - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: err = self.handle_error(event=event, session=session, message_id=self._message_id) - session.commit() yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): answer=output_moderation_answer ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Save message self._save_message(session=session, trace_manager=trace_manager) - session.commit() message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 061719d15a..1fb0dc6cf1 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Note: Since we're mocking ConversationVariable.from_variable, # we can't directly check the id, but we can verify add_all was called assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_no_variables_creates_all(self): """Test that all conversation variables are created when none exist in DB.""" @@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that all variables were created assert len(added_items) == 2, "Should have added both variables" assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_all_variables_exist_no_changes(self): """Test that no changes are made when all variables already exist in DB.""" @@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that no variables were added assert not mock_session.add_all.called, "Session add_all should not have been called" - assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 079df0b4e6..5d8faee897 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner): scalar=lambda *a, **k: MagicMock(), ), ), + sessionmaker=MagicMock( + return_value=MagicMock( + begin=MagicMock( + return_value=MagicMock( + __enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))), + __exit__=lambda *a, **k: False, + ), + ), + ), + ), select=MagicMock(), db=MagicMock(engine=MagicMock()), RedisChannel=MagicMock(), diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index dabd2594b4..d91bb85aee 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import contextmanager from types import SimpleNamespace +from unittest.mock import MagicMock import pytest from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus @@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline: def test_database_session_rolls_back_on_error(self, monkeypatch): pipeline = _make_pipeline() - calls = {"commit": 0, "rollback": 0} - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs + calls = {"enter": 0, "exit_exc": None} + class _BeginContext: def __enter__(self): - return self + calls["enter"] += 1 + return MagicMock() def __exit__(self, exc_type, exc, tb): + calls["exit_exc"] = exc_type return False - def commit(self): - calls["commit"] += 1 + class _Sessionmaker: + def __init__(self, *args, **kwargs): + pass - def rollback(self): - calls["rollback"] += 1 + def begin(self): + return _BeginContext() - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker) monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) with pytest.raises(RuntimeError, match="db error"): with pipeline._database_session(): raise RuntimeError("db error") - assert calls["commit"] == 0 - assert calls["rollback"] == 1 + assert calls["enter"] == 1 + assert calls["exit_exc"] is RuntimeError def test_node_retry_and_started_handlers_cover_none_and_value(self): pipeline = _make_pipeline()