mirror of https://github.com/langgenius/dify.git
refactor(api): inject sessionmaker into conversation variable updater (#30609)
This commit is contained in:
parent
f3ca8be9f9
commit
d12b91a01a
|
|
@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
|
|||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
|
|
@ -41,7 +42,7 @@ from models import Workflow
|
|||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import conversation_variable_updater_factory
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -202,7 +203,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(conversation_variable_updater_factory())
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account, ConversationVariable
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from services.conversation_variable_updater import conversation_variable_updater_factory
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
ConversationVariableNotExistsError,
|
||||
|
|
@ -337,7 +337,7 @@ class ConversationService:
|
|||
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
|
||||
|
||||
# Use the conversation variable updater to persist the changes
|
||||
updater = conversation_variable_updater_factory()
|
||||
updater = ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
updater.update(conversation_id, updated_variable)
|
||||
updater.flush()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
|
|
@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class ConversationVariableUpdaterImpl:
|
||||
class ConversationVariableUpdater:
|
||||
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||
self._session_maker: sessionmaker[Session] = session_maker
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable) -> None:
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
with self._session_maker() as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise ConversationVariableNotFoundError("conversation variable not found in the database")
|
||||
|
|
@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl:
|
|||
|
||||
def flush(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
||||
return ConversationVariableUpdaterImpl()
|
||||
|
|
|
|||
Loading…
Reference in New Issue