refactor(api): inject sessionmaker into conversation variable updater (#30609)

This commit is contained in:
-LAN- 2026-01-06 14:52:59 +08:00 committed by GitHub
parent f3ca8be9f9
commit d12b91a01a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 12 deletions

View File

@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
) )
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
@ -41,7 +42,7 @@ from models import Workflow
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable from models.workflow import ConversationVariable
from services.conversation_variable_updater import conversation_variable_updater_factory from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -202,7 +203,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
) )
workflow_entry.graph_engine.layer(persistence_layer) workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(conversation_variable_updater_factory()) conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer) workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers: for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer) workflow_entry.graph_engine.layer(layer)

View File

@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
from services.conversation_variable_updater import conversation_variable_updater_factory from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import ( from services.errors.conversation import (
ConversationNotExistsError, ConversationNotExistsError,
ConversationVariableNotExistsError, ConversationVariableNotExistsError,
@ -337,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes # Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory() updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable) updater.update(conversation_id, updated_variable)
updater.flush() updater.flush()

View File

@ -1,8 +1,7 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable from models import ConversationVariable
@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception):
pass pass
class ConversationVariableUpdaterImpl: class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None: def update(self, conversation_id: str, variable: Variable) -> None:
stmt = select(ConversationVariable).where( stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
) )
with Session(db.engine) as session: with self._session_maker() as session:
row = session.scalar(stmt) row = session.scalar(stmt)
if not row: if not row:
raise ConversationVariableNotFoundError("conversation variable not found in the database") raise ConversationVariableNotFoundError("conversation variable not found in the database")
@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl:
def flush(self) -> None: def flush(self) -> None:
pass pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()