diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 19d03345a5..2aa12ef157 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -520,9 +520,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): worker_thread.start() # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False): - workflow = _refresh_model(workflow) - message = _refresh_model(message) + with Session(bind=db.engine, expire_on_commit=False) as session: + workflow = _refresh_model(session=session, model=workflow) + message = _refresh_model(session=session, model=message) assert message is not None # workflow_ = session.get(Workflow, workflow.id) # assert workflow_ is not None @@ -691,30 +691,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload -def _refresh_model(model: Workflow, session: Session | None = None) -> Workflow: ... +def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ... @overload -def _refresh_model(model: Message, session: Session | None = None) -> Message: ... +def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ... -def _refresh_model(model: Any, session: Session | None = None) -> Any: - if session is not None and hasattr(session, "get"): - refresh_session = session - else: - refresh_session = Session(bind=db.engine, expire_on_commit=False) - - with refresh_session: - if isinstance(model, Workflow): - detached_workflow = refresh_session.get(Workflow, model.id) - assert detached_workflow is not None - return detached_workflow - - if isinstance(model, Message): - detached_message = refresh_session.get(Message, model.id) - assert detached_message is not None - return detached_message +def _refresh_model(*, session: Session | None = None, model: Any) -> Any: + if session is not None: + detached_model = session.get(type(model), model.id) + assert detached_model is not None + return detached_model + with Session(bind=db.engine, expire_on_commit=False) as refresh_session: detached_model = refresh_session.get(type(model), model.id) assert detached_model is not None return detached_model diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 441d2fcd17..cdb033111d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals: monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) - refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) + refreshed = _refresh_model(session=None, model=source_model) assert refreshed is detached_model