refactor(api): inject sessionmaker into conversation variable updater (#30609)

This commit is contained in:
-LAN- 2026-01-06 14:52:59 +08:00 committed by GitHub
parent f3ca8be9f9
commit d12b91a01a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 12 deletions

View File

@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -41,7 +42,7 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import conversation_variable_updater_factory
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@ -202,7 +203,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(conversation_variable_updater_factory())
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.conversation_variable_updater import conversation_variable_updater_factory
from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@ -337,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()

View File

@ -1,8 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception):
pass
class ConversationVariableUpdaterImpl:
class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
with self._session_maker() as session:
row = session.scalar(stmt)
if not row:
raise ConversationVariableNotFoundError("conversation variable not found in the database")
@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl:
def flush(self) -> None:
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()