diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 81a1d54199..389db8a972 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,8 +7,8 @@ from typing import Any, cast from flask import has_request_context from sqlalchemy import select -from sqlalchemy.orm import Session +from core.db.session_factory import session_factory from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool @@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user from models import Account, Tenant @@ -230,30 +229,32 @@ class WorkflowTool(Tool): """ Resolve user from database (worker/Celery context). """ + with session_factory.create_session() as session: + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = session.scalar(tenant_stmt) + if not tenant: + return None + + user_stmt = select(Account).where(Account.id == user_id) + user = session.scalar(user_stmt) + if user: + user.current_tenant = tenant + session.expunge(user) + return user + + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = session.scalar(end_user_stmt) + if end_user: + session.expunge(end_user) + return end_user - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) - tenant = db.session.scalar(tenant_stmt) - if not tenant: return None - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if user: - user.current_tenant = tenant - return user - - end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) - end_user = db.session.scalar(end_user_stmt) - if end_user: - return end_user - - return None - def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version """ - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): if not version: stmt = ( select(Workflow) @@ -265,22 +266,24 @@ class WorkflowTool(Tool): stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) workflow = session.scalar(stmt) - if not workflow: - raise ValueError("workflow not found or not published") + if not workflow: + raise ValueError("workflow not found or not published") - return workflow + session.expunge(workflow) + return workflow def _get_app(self, app_id: str) -> App: """ get the app by app id """ stmt = select(App).where(App.id == app_id) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): app = session.scalar(stmt) - if not app: - raise ValueError("app not found") + if not app: + raise ValueError("app not found") - return app + session.expunge(app) + return app def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 5d180c7cbc..cd45292488 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -228,11 +228,28 @@ def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.M def scalar(self, _stmt): return self.results.pop(0) + # SQLAlchemy Session APIs used by code under test + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + # support `with session_factory.create_session() as session:` + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") - db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + # Monkeypatch session factory to return our stub session + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([tenant, None, end_user]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), @@ -266,8 +283,23 @@ def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pyt def scalar(self, _stmt): return self.results.pop(0) - db_stub = SimpleNamespace(session=StubSession([None])) - monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + def expunge(self, *_args, **_kwargs): + pass + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + # Monkeypatch session factory to return our stub session with no tenant + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession([None]), + ) entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),