mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 19:27:23 +08:00
refactor(api): inject sessionmaker into conversation variable updater (#30609)
This commit is contained in:
parent
f3ca8be9f9
commit
d12b91a01a
@ -21,6 +21,7 @@ from core.app.entities.queue_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.moderation.base import ModerationError
|
from core.moderation.base import ModerationError
|
||||||
from core.moderation.input_moderation import InputModeration
|
from core.moderation.input_moderation import InputModeration
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
@ -41,7 +42,7 @@ from models import Workflow
|
|||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.model import App, Conversation, Message, MessageAnnotation
|
from models.model import App, Conversation, Message, MessageAnnotation
|
||||||
from models.workflow import ConversationVariable
|
from models.workflow import ConversationVariable
|
||||||
from services.conversation_variable_updater import conversation_variable_updater_factory
|
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -202,7 +203,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
workflow_entry.graph_engine.layer(persistence_layer)
|
workflow_entry.graph_engine.layer(persistence_layer)
|
||||||
conversation_variable_layer = ConversationVariablePersistenceLayer(conversation_variable_updater_factory())
|
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||||
|
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||||
|
)
|
||||||
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
||||||
for layer in self._graph_engine_layers:
|
for layer in self._graph_engine_layers:
|
||||||
workflow_entry.graph_engine.layer(layer)
|
workflow_entry.graph_engine.layer(layer)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from libs.datetime_utils import naive_utc_now
|
|||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models import Account, ConversationVariable
|
from models import Account, ConversationVariable
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from services.conversation_variable_updater import conversation_variable_updater_factory
|
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from services.errors.conversation import (
|
from services.errors.conversation import (
|
||||||
ConversationNotExistsError,
|
ConversationNotExistsError,
|
||||||
ConversationVariableNotExistsError,
|
ConversationVariableNotExistsError,
|
||||||
@ -337,7 +337,7 @@ class ConversationService:
|
|||||||
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
|
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
|
||||||
|
|
||||||
# Use the conversation variable updater to persist the changes
|
# Use the conversation variable updater to persist the changes
|
||||||
updater = conversation_variable_updater_factory()
|
updater = ConversationVariableUpdater(session_factory.get_session_maker())
|
||||||
updater.update(conversation_id, updated_variable)
|
updater.update(conversation_id, updated_variable)
|
||||||
updater.flush()
|
updater.flush()
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.variables.variables import Variable
|
from core.variables.variables import Variable
|
||||||
from extensions.ext_database import db
|
|
||||||
from models import ConversationVariable
|
from models import ConversationVariable
|
||||||
|
|
||||||
|
|
||||||
@ -10,12 +9,15 @@ class ConversationVariableNotFoundError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariableUpdaterImpl:
|
class ConversationVariableUpdater:
|
||||||
|
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||||
|
self._session_maker: sessionmaker[Session] = session_maker
|
||||||
|
|
||||||
def update(self, conversation_id: str, variable: Variable) -> None:
|
def update(self, conversation_id: str, variable: Variable) -> None:
|
||||||
stmt = select(ConversationVariable).where(
|
stmt = select(ConversationVariable).where(
|
||||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||||
)
|
)
|
||||||
with Session(db.engine) as session:
|
with self._session_maker() as session:
|
||||||
row = session.scalar(stmt)
|
row = session.scalar(stmt)
|
||||||
if not row:
|
if not row:
|
||||||
raise ConversationVariableNotFoundError("conversation variable not found in the database")
|
raise ConversationVariableNotFoundError("conversation variable not found in the database")
|
||||||
@ -24,7 +26,3 @@ class ConversationVariableUpdaterImpl:
|
|||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
|
|
||||||
return ConversationVariableUpdaterImpl()
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user