diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a440b35035..a2ae8dec5b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -21,6 +21,7 @@ from core.app.entities.queue_entities import ( ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion @@ -41,7 +42,7 @@ from models import Workflow from models.enums import UserFrom from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable -from services.conversation_variable_updater import conversation_variable_updater_factory +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -202,7 +203,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) - conversation_variable_layer = ConversationVariablePersistenceLayer(conversation_variable_updater_factory()) + conversation_variable_layer = ConversationVariablePersistenceLayer( + ConversationVariableUpdater(session_factory.get_session_maker()) + ) workflow_entry.graph_engine.layer(conversation_variable_layer) for layer in self._graph_engine_layers: workflow_entry.graph_engine.layer(layer) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 038c9feb12..295d48d8a1 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message -from services.conversation_variable_updater import conversation_variable_updater_factory +from services.conversation_variable_updater import ConversationVariableUpdater from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, @@ -337,7 +337,7 @@ class ConversationService: updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) # Use the conversation variable updater to persist the changes - updater = conversation_variable_updater_factory() + updater = ConversationVariableUpdater(session_factory.get_session_maker()) updater.update(conversation_id, updated_variable) updater.flush() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 2da507fac7..acc0ec2b22 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,8 +1,7 @@ from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.variables.variables import Variable -from extensions.ext_database import db from models import ConversationVariable @@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception): pass -class ConversationVariableUpdaterImpl: +class ConversationVariableUpdater: + def __init__(self, session_maker: sessionmaker[Session]) -> None: + self._session_maker: sessionmaker[Session] = session_maker + def update(self, conversation_id: str, variable: Variable) -> None: stmt = select(ConversationVariable).where( ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ) - with Session(db.engine) as session: + with self._session_maker() as session: row = session.scalar(stmt) if not row: raise ConversationVariableNotFoundError("conversation variable not found in the database") @@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl: def flush(self) -> None: pass - - -def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: - return ConversationVariableUpdaterImpl()