diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 1751b45d9b..30334f5da8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -203,7 +203,7 @@ class WorkflowTool(Tool): Resolve user object in both HTTP and worker contexts. In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser). - In worker context: load Account from database by user_id (only returns Account, never EndUser). + In worker context: load Account(knowledge pipeline) or EndUser(trigger) from database by user_id. Returns: Account | EndUser | None: The resolved user object, or None if resolution fails. @@ -224,24 +224,28 @@ class WorkflowTool(Tool): logger.warning("Failed to resolve user from request context: %s", e) return None - def _resolve_user_from_database(self, user_id: str) -> Account | None: + def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None: """ Resolve user from database (worker/Celery context). """ - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if not user: - return None - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) tenant = db.session.scalar(tenant_stmt) if not tenant: return None - user.current_tenant = tenant + user_stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(user_stmt) + if user: + user.current_tenant = tenant + return user - return user + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = db.session.scalar(end_user_stmt) + if end_user: + return end_user + + return None def _get_workflow(self, app_id: str, version: str) -> Workflow: """ diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 02bf8e82f1..5d180c7cbc 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.app.entities.app_invoke_entities import InvokeFrom @@ -214,3 +216,76 @@ def test_create_variable_message(): assert message.message.variable_name == var_name assert message.message.variable_value == var_value assert message.message.stream is False + + +def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch): + """Ensure worker context can resolve EndUser when Account is missing.""" + + class StubSession: + def __init__(self, results: list): + self.results = results + + def scalar(self, _stmt): + return self.results.pop(0) + + tenant = SimpleNamespace(id="tenant_id") + end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") + db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) + + monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + resolved_user = tool._resolve_user_from_database(user_id=end_user.id) + + assert resolved_user is end_user + + +def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch): + """Return None if tenant cannot be found in worker context.""" + + class StubSession: + def __init__(self, results: list): + self.results = results + + def scalar(self, _stmt): + return self.results.pop(0) + + db_stub = SimpleNamespace(session=StubSession([None])) + monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + resolved_user = tool._resolve_user_from_database(user_id="any") + + assert resolved_user is None