diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 5adf04611d..79da01863b 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,6 +5,7 @@ from typing import Any from sqlalchemy import select +from core.app.entities.app_invoke_entities import InvokeFrom from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user -from models.model import App +from models.account import Account +from models.model import App, EndUser from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -79,11 +81,13 @@ class WorkflowTool(Tool): generator = WorkflowAppGenerator() assert self.runtime is not None assert self.runtime.invoke_from is not None - assert current_user is not None + user = self._resolve_user(user_id) + if user is None: + raise ToolInvokeError("workflow tool invoke missing user context") result = generator.generate( app_model=app, workflow=workflow, - user=current_user, + user=user, args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, @@ -227,3 +231,26 @@ class WorkflowTool(Tool): elif transfer_method == FileTransferMethod.LOCAL_FILE: file_dict["upload_file_id"] = file_dict.get("related_id") return file_dict + + def _resolve_user(self, user_id: str) -> Account | EndUser | None: + runtime = self.runtime + try: + user_candidate = current_user + except RuntimeError: + user_candidate = None + + if user_candidate is not None and getattr(user_candidate, "is_authenticated", False): + return user_candidate + + if not user_id or runtime is None: + return None + + invoke_from = runtime.invoke_from + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.PUBLISHED}: + end_user = ( + db.session.query(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == runtime.tenant_id).first() + ) + if end_user: + return end_user + + return db.session.query(Account).where(Account.id == user_id).first() diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 17e3ebeea0..e69248c291 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -40,9 +40,64 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel lambda *args, **kwargs: {"data": {"error": "oops"}}, ) monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + monkeypatch.setattr( + WorkflowTool, + "_resolve_user", + lambda self, _user_id: type("DummyUser", (), {"id": _user_id, "is_authenticated": True})(), + raising=False, + ) with pytest.raises(ToolInvokeError) as exc_info: # WorkflowTool always returns a generator, so we need to iterate to # actually `run` the tool. list(tool.invoke("test_user", {})) assert exc_info.value.args == ("oops",) + + +def test_workflow_tool_falls_back_to_user_resolver_when_no_current_user(monkeypatch: pytest.MonkeyPatch): + entity = ToolEntity( + identity=ToolIdentity(author="tester", name="work", label=I18nObject(en_US="work"), provider="prv"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="tenant-id", invoke_from=InvokeFrom.SERVICE_API) + tool = WorkflowTool( + workflow_app_id="app-id", + workflow_as_tool_id="tool-id", + version="1", + workflow_entities={}, + workflow_call_depth=0, + entity=entity, + runtime=runtime, + ) + + # keep tool internals simple for the test + monkeypatch.setattr(tool, "_get_app", lambda *_args, **_kwargs: object()) + monkeypatch.setattr(tool, "_get_workflow", lambda *_args, **_kwargs: object()) + monkeypatch.setattr(tool, "_transform_args", lambda tool_parameters, **_: (tool_parameters, [])) + + captured: dict[str, str] = {} + + class DummyUser: + id = "dummy-user" + is_authenticated = True + + dummy_user = DummyUser() + + def fake_resolver(self, user_id: str): + captured["user_id"] = user_id + return dummy_user + + def fake_generate(self, *, user, **_kwargs): + assert user is dummy_user + return {"data": {"outputs": {}}} + + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", fake_generate) + monkeypatch.setattr("core.tools.workflow_as_tool.tool.current_user", None) + monkeypatch.setattr(WorkflowTool, "_resolve_user", fake_resolver, raising=False) + + result = list(tool.invoke("user-123", {})) + + assert captured["user_id"] == "user-123" + assert len(result) == 2 # text + json outputs