mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor(api): use sessionmaker in core app generators & pipelines (#34771)
This commit is contained in:
parent
289f091bf9
commit
02a9f0abca
@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import VariableLoader
|
||||
from graphon.variables.variables import Variable
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Convert to Variable objects for use in the workflow
|
||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
|
||||
@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Union
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
@contextmanager
|
||||
def _database_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
yield session
|
||||
|
||||
def _ensure_workflow_initialized(self):
|
||||
"""Fluent validation for workflow state."""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
session.commit()
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
|
||||
@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
# Note: Since we're mocking ConversationVariable.from_variable,
|
||||
# we can't directly check the id, but we can verify add_all was called
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_no_variables_creates_all(self):
|
||||
"""Test that all conversation variables are created when none exist in DB."""
|
||||
@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
# Verify that all variables were created
|
||||
assert len(added_items) == 2, "Should have added both variables"
|
||||
assert mock_session.add_all.called, "Session add_all should have been called"
|
||||
assert mock_session.commit.called, "Session commit should have been called"
|
||||
|
||||
def test_all_variables_exist_no_changes(self):
|
||||
"""Test that no changes are made when all variables already exist in DB."""
|
||||
@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Verify that no variables were added
|
||||
assert not mock_session.add_all.called, "Session add_all should not have been called"
|
||||
assert mock_session.commit.called, "Session commit should still be called"
|
||||
|
||||
@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
sessionmaker=MagicMock(
|
||||
return_value=MagicMock(
|
||||
begin=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
|
||||
__exit__=lambda *a, **k: False,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
RedisChannel=MagicMock(),
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus
|
||||
@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline:
|
||||
|
||||
def test_database_session_rolls_back_on_error(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"commit": 0, "rollback": 0}
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
calls = {"enter": 0, "exit_exc": None}
|
||||
|
||||
class _BeginContext:
|
||||
def __enter__(self):
|
||||
return self
|
||||
calls["enter"] += 1
|
||||
return MagicMock()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
calls["exit_exc"] = exc_type
|
||||
return False
|
||||
|
||||
def commit(self):
|
||||
calls["commit"] += 1
|
||||
class _Sessionmaker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def rollback(self):
|
||||
calls["rollback"] += 1
|
||||
def begin(self):
|
||||
return _BeginContext()
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
with pytest.raises(RuntimeError, match="db error"):
|
||||
with pipeline._database_session():
|
||||
raise RuntimeError("db error")
|
||||
|
||||
assert calls["commit"] == 0
|
||||
assert calls["rollback"] == 1
|
||||
assert calls["enter"] == 1
|
||||
assert calls["exit_exc"] is RuntimeError
|
||||
|
||||
def test_node_retry_and_started_handlers_cover_none_and_value(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user