refactor(api): use sessionmaker in core app generators & pipelines (#34771)

This commit is contained in:
carlos4s 2026-04-08 18:15:58 -05:00 committed by GitHub
parent 289f091bf9
commit 02a9f0abca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 49 additions and 46 deletions

View File

@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import Variable
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:
@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:

View File

@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes
from graphon.runtime import GraphRuntimeState
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@ -7,7 +7,7 @@ from typing import Union
from graphon.entities import WorkflowStartReason
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@ -1,6 +1,6 @@
from graphon.model_runtime.entities.llm_entities import LLMUsage
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.entities.model_entities import ModelStatus
@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
pool_type="paid",
)
else:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
)
)
session.execute(stmt)
session.commit()

View File

@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
event = message.event
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
answer=output_moderation_answer
)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):

View File

@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Note: Since we're mocking ConversationVariable.from_variable,
# we can't directly check the id, but we can verify add_all was called
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_no_variables_creates_all(self):
"""Test that all conversation variables are created when none exist in DB."""
@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Verify that all variables were created
assert len(added_items) == 2, "Should have added both variables"
assert mock_session.add_all.called, "Session add_all should have been called"
assert mock_session.commit.called, "Session commit should have been called"
def test_all_variables_exist_no_changes(self):
"""Test that no changes are made when all variables already exist in DB."""
@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Patch the necessary components
with (
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
):
# Setup mocks
mock_session_class.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables:
# Verify that no variables were added
assert not mock_session.add_all.called, "Session add_all should not have been called"
assert mock_session.commit.called, "Session commit should still be called"

View File

@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner):
scalar=lambda *a, **k: MagicMock(),
),
),
sessionmaker=MagicMock(
return_value=MagicMock(
begin=MagicMock(
return_value=MagicMock(
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
__exit__=lambda *a, **k: False,
),
),
),
),
select=MagicMock(),
db=MagicMock(engine=MagicMock()),
RedisChannel=MagicMock(),

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus
@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline:
def test_database_session_rolls_back_on_error(self, monkeypatch):
pipeline = _make_pipeline()
calls = {"commit": 0, "rollback": 0}
class _Session:
def __init__(self, *args, **kwargs):
_ = args, kwargs
calls = {"enter": 0, "exit_exc": None}
class _BeginContext:
def __enter__(self):
return self
calls["enter"] += 1
return MagicMock()
def __exit__(self, exc_type, exc, tb):
calls["exit_exc"] = exc_type
return False
def commit(self):
calls["commit"] += 1
class _Sessionmaker:
def __init__(self, *args, **kwargs):
pass
def rollback(self):
calls["rollback"] += 1
def begin(self):
return _BeginContext()
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker)
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
with pytest.raises(RuntimeError, match="db error"):
with pipeline._database_session():
raise RuntimeError("db error")
assert calls["commit"] == 0
assert calls["rollback"] == 1
assert calls["enter"] == 1
assert calls["exit_exc"] is RuntimeError
def test_node_retry_and_started_handlers_cover_none_and_value(self):
pipeline = _make_pipeline()