diff --git a/.agents/skills/e2e-cucumber-playwright/SKILL.md b/.agents/skills/e2e-cucumber-playwright/SKILL.md new file mode 100644 index 00000000000..de6b58f26df --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/SKILL.md @@ -0,0 +1,79 @@ +--- +name: e2e-cucumber-playwright +description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository. +--- + +# Dify E2E Cucumber + Playwright + +Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite. + +## Scope + +- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`. +- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead. +- Do not use this skill for backend test or API review tasks under `api/`. + +## Read Order + +1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first. +2. Read only the files directly involved in the task: + - target `.feature` files under `e2e/features/` + - related step files under `e2e/features/step-definitions/` + - `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters + - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter +3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved. +4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved. +5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern. + +## Local Rules + +- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer. +- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions. +- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps. +- Browser session behavior comes from `features/support/hooks.ts`: + - default: authenticated session with shared storage state + - `@unauthenticated`: clean browser context + - `@authenticated`: readability/selective-run tag only unless implementation changes + - `@fresh`: only for `e2e:full*` flows +- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture. + +## Workflow + +1. Rebuild local context. + - Inspect the target feature area. + - Reuse an existing step when wording and behavior already match. + - Add a new step only for a genuinely new user action or assertion. + - Keep edits close to the current capability folder unless the step is broadly reusable. +2. Write behavior-first scenarios. + - Describe user-observable behavior, not DOM mechanics. + - Keep each scenario focused on one workflow or outcome. + - Keep scenarios independent and re-runnable. +3. Write step definitions in the local style. + - Keep one step to one user-visible action or one assertion. + - Prefer Cucumber Expressions such as `{string}` and `{int}`. + - Scope locators to stable containers when the page has repeated elements. + - Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them. +4. Use Playwright in the local style. + - Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts. + - Use web-first `expect(...)` assertions. + - Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior. +5. Validate narrowly. + - Run the narrowest tagged scenario or flow that exercises the change. + - Run `pnpm -C e2e check`. + - Broaden verification only when the change affects hooks, tags, setup, or shared step semantics. + +## Review Checklist + +- Does the scenario describe behavior rather than implementation? +- Does it fit the current session model, tags, and `DifyWorld` usage? +- Should an existing step be reused instead of adding a new one? +- Are locators user-facing and assertions web-first? +- Does the change introduce hidden coupling across scenarios, tags, or instance state? +- Does it document or implement behavior that differs from the real hooks or configuration? + +Lead findings with correctness, flake risk, and architecture drift. + +## References + +- [`references/playwright-best-practices.md`](references/playwright-best-practices.md) +- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) diff --git a/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml new file mode 100644 index 00000000000..605cce041d9 --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "E2E Cucumber + Playwright" + short_description: "Write and review Dify E2E scenarios." + default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/." diff --git a/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md new file mode 100644 index 00000000000..d7a1a52852b --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md @@ -0,0 +1,93 @@ +# Cucumber Best Practices For Dify E2E + +Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite. + +Official sources: + +- https://cucumber.io/docs/guides/10-minute-tutorial/ +- https://cucumber.io/docs/cucumber/step-definitions/ +- https://cucumber.io/docs/cucumber/cucumber-expressions/ + +## What Matters Most + +### 1. Treat scenarios as executable specifications + +Cucumber scenarios should describe examples of behavior, not test implementation recipes. + +Apply it like this: + +- write what the user does and what should happen +- avoid UI-internal wording such as selector details, DOM structure, or component names +- keep language concrete enough that the scenario reads like living documentation + +### 2. Keep scenarios focused + +A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it. + +In Dify's suite, this means: + +- one capability-focused scenario per feature path +- no long setup chains when existing bootstrap or reusable steps already cover them +- no hidden dependency on another scenario's side effects + +### 3. Reuse steps, but only when behavior really matches + +Good reuse reduces duplication. Bad reuse hides meaning. + +Prefer reuse when: + +- the user action is genuinely the same +- the expected outcome is genuinely the same +- the wording stays natural across features + +Write a new step when: + +- the behavior is materially different +- reusing the old wording would make the scenario misleading +- a supposedly generic step would become an implementation-detail wrapper + +### 4. Prefer Cucumber Expressions + +Use Cucumber Expressions for parameters unless regex is clearly necessary. + +Common examples: + +- `{string}` for labels, names, and visible text +- `{int}` for counts +- `{float}` for decimal values +- `{word}` only when the value is truly a single token + +Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler. + +### 5. Keep step definitions thin and meaningful + +Step definitions are glue between Gherkin and automation, not a second abstraction language. + +For Dify: + +- type `this` as `DifyWorld` +- use `async function` +- keep each step to one user-visible action or assertion +- rely on `DifyWorld` and existing support code for shared context +- avoid leaking cross-scenario state + +### 6. Use tags intentionally + +Tags should communicate run scope or session semantics, not become ad hoc metadata. + +In Dify's current suite: + +- capability tags group related scenarios +- `@unauthenticated` changes session behavior +- `@authenticated` is descriptive/selective, not a behavior switch by itself +- `@fresh` belongs to reset/full-install flows only + +If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it. + +## Review Questions + +- Does the scenario read like a real example of product behavior? +- Are the steps behavior-oriented instead of implementation-oriented? +- Is a reused step still truthful in this feature? +- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement? +- Would a new reader understand the outcome without opening the step-definition file? diff --git a/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md new file mode 100644 index 00000000000..02e763d46b6 --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md @@ -0,0 +1,96 @@ +# Playwright Best Practices For Dify E2E + +Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite. + +Official sources: + +- https://playwright.dev/docs/best-practices +- https://playwright.dev/docs/locators +- https://playwright.dev/docs/test-assertions +- https://playwright.dev/docs/browser-contexts + +## What Matters Most + +### 1. Keep scenarios isolated + +Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`. + +Apply it like this: + +- do not depend on another scenario having run first +- do not persist ad hoc scenario state outside `DifyWorld` +- do not couple ordinary scenarios to `@fresh` behavior +- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes + +### 2. Prefer user-facing locators + +Playwright recommends built-in locators that reflect what users perceive on the page. + +Preferred order in this repository: + +1. `getByRole` +2. `getByLabel` +3. `getByPlaceholder` +4. `getByText` +5. `getByTestId` when an explicit test contract is the most stable option + +Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical. + +Also remember: + +- repeated content usually needs scoping to a stable container +- exact text matching is often too brittle when role/name or label already exists +- `getByTestId` is acceptable when semantics are weak but the contract is intentional + +### 3. Use web-first assertions + +Playwright assertions auto-wait and retry. Prefer them over manual state inspection. + +Prefer: + +- `await expect(page).toHaveURL(...)` +- `await expect(locator).toBeVisible()` +- `await expect(locator).toBeHidden()` +- `await expect(locator).toBeEnabled()` +- `await expect(locator).toHaveText(...)` + +Avoid: + +- `expect(await locator.isVisible()).toBe(true)` +- custom polling loops for DOM state +- `waitForTimeout` as synchronization + +If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit. + +### 4. Let actions wait for actionability + +Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity. + +Good pattern: + +- assert a meaningful visible state when that is part of the behavior +- then click/fill/select via locator APIs + +Bad pattern: + +- stack arbitrary waits before every action +- wait on unstable implementation details instead of the visible state the user cares about + +### 5. Match debugging to the current suite + +Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures: + +- full-page screenshots +- page HTML +- console errors +- page errors + +Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling. + +## Review Questions + +- Would this locator survive DOM refactors that do not change user-visible behavior? +- Is this assertion using Playwright's retrying semantics? +- Is any explicit wait masking a real readiness problem? +- Does this code preserve per-scenario isolation? +- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model? diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 4da070bdbfe..105c979c585 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path: - ✅ **Import real project components** directly (including base components and siblings) - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers -- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`) - ❌ **DO NOT mock** sibling/child components in the same directory > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. @@ -325,12 +325,12 @@ For more detailed information, refer to: ### Reference Examples in Codebase - `web/utils/classnames.spec.ts` - Utility function tests -- `web/app/components/base/button/index.spec.tsx` - Component tests +- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests - `web/__mocks__/provider-context.ts` - Mock factory example ### Project Configuration -- `web/vitest.config.ts` - Vitest configuration +- `web/vite.config.ts` - Vite/Vitest configuration - `web/vitest.setup.ts` - Test environment setup - `web/scripts/analyze-component.js` - Component analysis tool - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 10b8fb66f9f..99258498dd1 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Integration vs Mocking -- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.) - [ ] Import real project components instead of mocking - [ ] Only mock: API calls, complex context providers, third-party libs with side effects - [ ] Prefer integration testing when using single spec file @@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks -- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5f..8c2f1c0c583 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -2,29 +2,27 @@ ## ⚠️ Important: What NOT to Mock -### DO NOT Mock Base Components +### DO NOT Mock Base Components or dify-ui Primitives -**Never mock components from `@/app/components/base/`** such as: +**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as: -- `Loading`, `Spinner` -- `Button`, `Input`, `Select` -- `Tooltip`, `Modal`, `Dropdown` -- `Icon`, `Badge`, `Tag` +- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag` +- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast` **Why?** -- Base components will have their own dedicated tests +- These components have their own dedicated tests - Mocking them creates false positives (tests pass but real integration fails) - Using real components tests actual integration behavior ```typescript -// ❌ WRONG: Don't mock base components +// ❌ WRONG: Don't mock base components or dify-ui primitives vi.mock('@/app/components/base/loading', () => () =>
+ You were mentioned in a workflow comment
+Hi {{ mentioned_name }},
+{{ commenter_name }} mentioned you in {{ app_name }}.
+Open {{ application_title }} to reply to the comment.
+
+ 你在工作流评论中被提及
+你好,{{ mentioned_name }}:
+{{ commenter_name }} 在 {{ app_name }} 中提及了你。
+{{ comment_content }}
+请在 {{ application_title }} 中查看并回复此评论。
+
+ You were mentioned in a workflow comment
+Hi {{ mentioned_name }},
+{{ commenter_name }} mentioned you in {{ app_name }}.
+{{ comment_content }}
+Open {{ application_title }} to reply to the comment.
+
+ 你在工作流评论中被提及
+你好,{{ mentioned_name }}:
+{{ commenter_name }} 在 {{ app_name }} 中提及了你。
+{{ comment_content }}
+请在 {{ application_title }} 中查看并回复此评论。
+Status: "{{ status }}"
-'''code block'''- -""" - inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - # Verify the template rendered correctly with all special characters - output = result["result"] - assert 'value="TASK-123"' in output - assert "" in output - assert 'Status: "completed"' in output - assert "'''code block'''" in output - - -def test_jinja2_template_with_html_textarea_prefill(): - """ - Specific test for HTML textarea with Jinja2 variable pre-fill. - Verifies fix for issue #26818. - """ - template = "" - notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes." - inputs = {"notes": notes_content} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - expected_output = f"" - assert result["result"] == expected_output - - -def test_jinja2_assemble_runner_script_encodes_template(): - """Test that assemble_runner_script properly base64 encodes the template.""" - template = "Hello {{ name }}!" - inputs = {"name": "World"} - - script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs) - - # The template should be base64 encoded in the script - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - assert template_b64 in script - # The raw template should NOT appear in the script (it's encoded) - assert "Hello {{ name }}!" not in script diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py deleted file mode 100644 index 25af312afa4..00000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ /dev/null @@ -1,36 +0,0 @@ -from textwrap import dedent - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.PYTHON3 - - -def test_python3_plain(): - code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == "Hello World\n" - - -def test_python3_json(): - code = dedent(""" - import json - print(json.dumps({'Hello': 'World'})) - """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == '{"Hello": "World"}\n' - - -def test_python3_with_code_template(): - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} - ) - assert result == {"result": "HelloWorld"} - - -def test_python3_get_runner_script(): - runner_script = Python3TemplateTransformer.get_runner_script() - assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22b..aaa60929935 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,18 @@ import time import uuid import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import NodeRunResult -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e7388..b9f7b9575be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool - from core.workflow.system_variables import build_system_variables - # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), @@ -724,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index f0f3fcead19..3eead701639 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,20 +4,20 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -76,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fe512c25856..f2eabb86c34 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,17 +3,17 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -70,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 2d728569bee..e2e0723fb86 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,15 +1,15 @@ import time import uuid -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.template_rendering import TemplateRenderError - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -87,8 +87,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 750ced7075e..a8e9422c1ee 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,18 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import StreamCompletedEvent -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.tool.tool_node import ToolNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -61,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef74893f070..66a25e5dafb 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -369,7 +369,7 @@ def _create_app_with_containers() -> Flask: # Create and configure the Flask application logger.info("Initializing Flask application...") - app = create_app() + sio_app, app = create_app() logger.info("Flask application created successfully") # Initialize database schema diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 0841217fcfe..18755ef012d 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -234,6 +234,35 @@ class TestAppEndpoints: } ) + def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + api = app_module.AppIconApi() + method = _unwrap(api.post) + payload = { + "icon": "https://example.com/icon.png", + "icon_type": "image", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app_icon.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace()) + + assert response == {"id": "app-1"} + assert app_service.update_app_icon.call_args.args[1:] == ( + payload["icon"], + payload["icon_background"], + app_module.IconType.IMAGE, + ) + class TestOpsTraceEndpoints: @pytest.fixture @@ -313,6 +342,21 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "test-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = "Test site" + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -328,13 +372,29 @@ class TestSiteEndpoints: with app.test_request_context("/", json={"title": "My Site"}): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["title"] == "My Site" def test_app_site_access_token_reset(self, app, monkeypatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "old-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = None + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -351,7 +411,8 @@ class TestSiteEndpoints: with app.test_request_context("/"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["access_token"] == "code" class TestWorkflowEndpoints: @@ -400,7 +461,7 @@ class TestWorkflowAppLogEndpoints: monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker) def fake_get_paginate(self, **_kwargs): - return {"items": [], "total": 0} + return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} monkeypatch.setattr( workflow_app_log_module.WorkflowAppService, @@ -411,7 +472,7 @@ class TestWorkflowAppLogEndpoints: with app.test_request_context("/?page=1&limit=20"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result == {"items": [], "total": 0} + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} class TestWorkflowDraftVariableEndpoints: @@ -555,7 +616,7 @@ class TestWorkflowTriggerEndpoints: trigger = MagicMock() session = MagicMock() - session.query.return_value.where.return_value.first.return_value = trigger + session.scalar.return_value = trigger class DummySessionCtx: def __enter__(self): @@ -576,7 +637,8 @@ class TestWorkflowTriggerEndpoints: with app.test_request_context("/?node_id=node-1"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is trigger + assert isinstance(result, dict) + assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys()) class TestWrapsEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index d8c6821f8db..25d19cf35a3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -96,6 +96,56 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED + def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.commit.assert_called_once_with() + fake_session.rollback.assert_not_called() + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.rollback.assert_called_once_with() + fake_session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + class TestAppImportConfirmApi: @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5cc458fe2ef..5a22f81a697 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,15 +4,15 @@ import json import uuid from flask.testing import FlaskClient -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun @@ -30,7 +30,7 @@ def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py new file mode 100644 index 00000000000..fad0b8b10e2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -0,0 +1,73 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.console.app.conversation import _get_conversation +from models.enums import ConversationFromSource +from models.model import AppMode, Conversation +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def test_get_conversation_mark_read_keeps_updated_at_unchanged( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + original_updated_at = datetime(2026, 2, 8, 0, 0, 0) + conversation = Conversation( + app_id=app.id, + name="read timestamp test", + inputs={}, + status="normal", + mode=AppMode.CHAT, + from_source=ConversationFromSource.CONSOLE, + from_account_id=account.id, + updated_at=original_updated_at, + ) + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + + read_at = datetime(2026, 2, 9, 0, 0, 0) + + with ( + patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ), + patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=read_at, + autospec=True, + ), + ): + loaded = _get_conversation(app, conversation.id) + + db_session_with_containers.refresh(conversation) + + assert loaded.id == conversation.id + assert conversation.read_at == read_at + assert conversation.read_account_id == account.id + assert conversation.updated_at == original_updated_at + + +def test_get_conversation_raises_not_found_for_missing_conversation( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ): + with pytest.raises(NotFound): + _get_conversation(app, "00000000-0000-0000-0000-000000000000") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py index 6b51ec98bcd..eff6dd789d7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -148,14 +148,18 @@ def test_chat_message_list_success( account.id, created_at_offset_seconds=1, ) + # Capture IDs before the HTTP request detaches ORM instances from the session + app_id = app.id + conversation_id = conversation.id + second_id = second.id with patch( "controllers.console.app.message.attach_message_extra_contents", side_effect=_attach_message_extra_contents, ): response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-messages", - query_string={"conversation_id": conversation.id, "limit": 1}, + f"/console/api/apps/{app_id}/chat-messages", + query_string={"conversation_id": conversation_id, "limit": 1}, headers=authenticate_console_client(test_client_with_containers, account), ) @@ -165,7 +169,7 @@ def test_chat_message_list_success( assert payload["limit"] == 1 assert payload["has_more"] is True assert len(payload["data"]) == 1 - assert payload["data"][0]["id"] == second.id + assert payload["data"][0]["id"] == second_id def test_message_feedback_not_found( diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index 8ddf867370e..290be876974 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -3,12 +3,12 @@ import uuid from flask.testing import FlaskClient -from graphon.variables.segments import StringSegment from sqlalchemy import select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from factories.variable_factory import segment_to_variable +from graphon.variables.segments import StringSegment from models import Workflow from models.model import AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319c..320da85b601 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade9..d2703ed5cce 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579f..1eabb45422a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e31..50249bcd743 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py index 9e2084f3939..a8ecf94da19 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/helpers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -11,7 +11,7 @@ from constants import HEADER_NAME_CSRF_TOKEN from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.model import App, AppMode from services.account_service import AccountService @@ -37,7 +37,7 @@ def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Ten db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py b/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py index 90670a9db59..21b395a04c6 100644 --- a/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py +++ b/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py @@ -444,7 +444,7 @@ class TestMCPAppApi: ) session = MagicMock() - session.query().where().first.side_effect = [server, app] + session.scalar.side_effect = [server, app] result_server, result_app = api._get_mcp_server_and_app("server-1", session) diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/test_containers_integration_tests/controllers/service_api/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/iris/__init__.py rename to api/tests/test_containers_integration_tests/controllers/service_api/__init__.py diff --git a/api/tests/integration_tests/vdb/lindorm/__init__.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/lindorm/__init__.py rename to api/tests/test_containers_integration_tests/controllers/service_api/dataset/__init__.py diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py similarity index 50% rename from api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py rename to api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 910d781cd02..9b913d6d3d6 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -1,17 +1,16 @@ """ -Unit tests for Service API Dataset controllers. +Integration tests for Service API Dataset controllers. + +Migrated from unit_tests/controllers/service_api/dataset/test_dataset.py. Tests coverage for: - DatasetCreatePayload, DatasetUpdatePayload Pydantic models - Tag-related payloads (create, update, delete, binding) - DatasetListQuery model -- DatasetService and TagService interfaces -- Permission validation patterns +- API endpoint error handling and controller behavior -Focus on: -- Pydantic model validation -- Error type mappings -- Service method interfaces +Services (DatasetService, TagService, DocumentService) remain mocked +since these test controller-level behavior. """ import uuid @@ -19,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound import services @@ -36,22 +36,23 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName from models.account import Account from models.dataset import DatasetPermissionEnum from models.enums import TagType -from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService -from services.tag_service import TagService +from models.model import Tag + +# --------------------------------------------------------------------------- +# Pydantic model validation tests +# --------------------------------------------------------------------------- class TestDatasetCreatePayload: """Test suite for DatasetCreatePayload Pydantic model.""" def test_payload_with_required_name(self): - """Test payload with required name field.""" payload = DatasetCreatePayload(name="Test Dataset") assert payload.name == "Test Dataset" assert payload.description == "" assert payload.permission == DatasetPermissionEnum.ONLY_ME def test_payload_with_all_fields(self): - """Test payload with all fields populated.""" payload = DatasetCreatePayload( name="Full Dataset", description="A comprehensive dataset description", @@ -70,28 +71,23 @@ class TestDatasetCreatePayload: assert payload.embedding_model_provider == "openai" def test_payload_name_length_validation_min(self): - """Test name minimum length validation.""" with pytest.raises(ValueError): DatasetCreatePayload(name="") def test_payload_name_length_validation_max(self): - """Test name maximum length validation (40 chars).""" with pytest.raises(ValueError): DatasetCreatePayload(name="A" * 41) def test_payload_description_max_length(self): - """Test description maximum length (400 chars).""" with pytest.raises(ValueError): DatasetCreatePayload(name="Dataset", description="A" * 401) @pytest.mark.parametrize("technique", ["high_quality", "economy"]) def test_payload_valid_indexing_techniques(self, technique): - """Test valid indexing technique values.""" payload = DatasetCreatePayload(name="Dataset", indexing_technique=technique) assert payload.indexing_technique == technique def test_payload_with_external_knowledge_settings(self): - """Test payload with external knowledge configuration.""" payload = DatasetCreatePayload( name="External Dataset", external_knowledge_api_id="api_123", external_knowledge_id="knowledge_456" ) @@ -103,20 +99,17 @@ class TestDatasetUpdatePayload: """Test suite for DatasetUpdatePayload Pydantic model.""" def test_payload_all_optional(self): - """Test payload with all fields optional.""" payload = DatasetUpdatePayload() assert payload.name is None assert payload.description is None assert payload.permission is None def test_payload_with_partial_update(self): - """Test payload with partial update fields.""" payload = DatasetUpdatePayload(name="Updated Name", description="Updated description") assert payload.name == "Updated Name" assert payload.description == "Updated description" def test_payload_with_permission_change(self): - """Test payload with permission update.""" payload = DatasetUpdatePayload( permission=DatasetPermissionEnum.PARTIAL_TEAM, partial_member_list=[{"user_id": "user_123", "role": "editor"}], @@ -125,12 +118,8 @@ class TestDatasetUpdatePayload: assert len(payload.partial_member_list) == 1 def test_payload_name_length_validation(self): - """Test name length constraints.""" - # Minimum is 1 with pytest.raises(ValueError): DatasetUpdatePayload(name="") - - # Maximum is 40 with pytest.raises(ValueError): DatasetUpdatePayload(name="A" * 41) @@ -139,7 +128,6 @@ class TestDatasetListQuery: """Test suite for DatasetListQuery Pydantic model.""" def test_query_with_defaults(self): - """Test query with default values.""" query = DatasetListQuery() assert query.page == 1 assert query.limit == 20 @@ -148,7 +136,6 @@ class TestDatasetListQuery: assert query.tag_ids == [] def test_query_with_all_filters(self): - """Test query with all filter fields.""" query = DatasetListQuery( page=3, limit=50, keyword="machine learning", include_all=True, tag_ids=["tag1", "tag2", "tag3"] ) @@ -159,7 +146,6 @@ class TestDatasetListQuery: assert len(query.tag_ids) == 3 def test_query_with_tag_filter(self): - """Test query with tag IDs filter.""" query = DatasetListQuery(tag_ids=["tag_abc", "tag_def"]) assert query.tag_ids == ["tag_abc", "tag_def"] @@ -168,22 +154,18 @@ class TestTagCreatePayload: """Test suite for TagCreatePayload Pydantic model.""" def test_payload_with_name(self): - """Test payload with required name.""" payload = TagCreatePayload(name="New Tag") assert payload.name == "New Tag" def test_payload_name_length_min(self): - """Test name minimum length (1).""" with pytest.raises(ValueError): TagCreatePayload(name="") def test_payload_name_length_max(self): - """Test name maximum length (50).""" with pytest.raises(ValueError): TagCreatePayload(name="A" * 51) def test_payload_with_unicode_name(self): - """Test payload with unicode characters.""" payload = TagCreatePayload(name="标签 🏷️ Тег") assert payload.name == "标签 🏷️ Тег" @@ -192,13 +174,11 @@ class TestTagUpdatePayload: """Test suite for TagUpdatePayload Pydantic model.""" def test_payload_with_name_and_id(self): - """Test payload with name and tag_id.""" payload = TagUpdatePayload(name="Updated Tag", tag_id="tag_123") assert payload.name == "Updated Tag" assert payload.tag_id == "tag_123" def test_payload_requires_tag_id(self): - """Test that tag_id is required.""" with pytest.raises(ValueError): TagUpdatePayload(name="Updated Tag") @@ -207,12 +187,10 @@ class TestTagDeletePayload: """Test suite for TagDeletePayload Pydantic model.""" def test_payload_with_tag_id(self): - """Test payload with tag_id.""" payload = TagDeletePayload(tag_id="tag_to_delete") assert payload.tag_id == "tag_to_delete" def test_payload_requires_tag_id(self): - """Test that tag_id is required.""" with pytest.raises(ValueError): TagDeletePayload() @@ -221,19 +199,16 @@ class TestTagBindingPayload: """Test suite for TagBindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - """Test payload with valid binding data.""" payload = TagBindingPayload(tag_ids=["tag1", "tag2"], target_id="dataset_123") assert len(payload.tag_ids) == 2 assert payload.target_id == "dataset_123" def test_payload_rejects_empty_tag_ids(self): - """Test that empty tag_ids are rejected.""" with pytest.raises(ValueError) as exc_info: TagBindingPayload(tag_ids=[], target_id="dataset_123") assert "Tag IDs is required" in str(exc_info.value) def test_payload_single_tag_id(self): - """Test payload with single tag ID.""" payload = TagBindingPayload(tag_ids=["single_tag"], target_id="dataset_456") assert payload.tag_ids == ["single_tag"] @@ -242,674 +217,14 @@ class TestTagUnbindingPayload: """Test suite for TagUnbindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - """Test payload with valid unbinding data.""" payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") assert payload.tag_id == "tag_123" assert payload.target_id == "dataset_456" -class TestDatasetTagsApi: - """Test suite for DatasetTagsApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.current_user") - @patch("controllers.service_api.dataset.dataset.TagService") - def test_get_tags_success(self, mock_tag_service, mock_current_user, app): - """Test successful retrieval of dataset tags.""" - # Arrange - mock_current_user needs to pass isinstance(current_user, Account) - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.current_tenant_id = "tenant_123" - # Replace the mock with our properly specced one - from controllers.service_api.dataset import dataset as dataset_module - - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Test Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag.binding_count = "0" # Required for Pydantic validation - must be string - mock_tag_service.get_tags.return_value = [mock_tag] - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="GET"): - api = DatasetTagsApi() - response, status_code = api.get("tenant_123") - - # Assert - assert status_code == 200 - assert len(response) == 1 - assert response[0]["id"] == "tag_1" - assert response[0]["name"] == "Test Tag" - mock_tag_service.get_tags.assert_called_once_with("knowledge", "tenant_123") - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_create_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful creation of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "new_tag_1" - mock_tag.name = "New Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag_service.save_tags.return_value = mock_tag - mock_service_api_ns.payload = {"name": "New Tag"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="POST", json={"name": "New Tag"}): - api = DatasetTagsApi() - response, status_code = api.post("tenant_123") - - # Assert - assert status_code == 200 - assert response["id"] == "new_tag_1" - assert response["name"] == "New Tag" - assert response["binding_count"] == 0 - finally: - dataset_module.current_user = original_current_user - - def test_create_tag_forbidden(self, app): - """Test tag creation without edit permissions.""" - # Arrange - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = False - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act & Assert - with app.test_request_context("/", method="POST"): - api = DatasetTagsApi() - with pytest.raises(Forbidden): - api.post("tenant_123") - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_update_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful update of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Updated Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag.binding_count = "5" - mock_tag_service.update_tags.return_value = mock_tag - mock_tag_service.get_tag_binding_count.return_value = 5 - mock_service_api_ns.payload = {"name": "Updated Tag", "tag_id": "tag_1"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="PATCH", json={"name": "Updated Tag", "tag_id": "tag_1"}): - api = DatasetTagsApi() - response, status_code = api.patch("tenant_123") - - # Assert - assert status_code == 200 - assert response["id"] == "tag_1" - assert response["name"] == "Updated Tag" - assert response["binding_count"] == 5 - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_delete_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful deletion of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.delete_tag.return_value = None - mock_service_api_ns.payload = {"tag_id": "tag_1"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="DELETE", json={"tag_id": "tag_1"}): - api = DatasetTagsApi() - response = api.delete("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.delete_tag.assert_called_once_with("tag_1") - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagBindingApi: - """Test suite for DatasetTagBindingApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_bind_tags_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful binding of tags to dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.save_tag_binding.return_value = None - payload = {"tag_ids": ["tag_1", "tag_2"], "target_id": "dataset_123"} - mock_service_api_ns.payload = payload - - from controllers.service_api.dataset.dataset import DatasetTagBindingApi - - try: - # Act - with app.test_request_context("/", method="POST", json=payload): - api = DatasetTagBindingApi() - response = api.post("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.save_tag_binding.assert_called_once_with( - {"tag_ids": ["tag_1", "tag_2"], "target_id": "dataset_123", "type": "knowledge"} - ) - finally: - dataset_module.current_user = original_current_user - - def test_bind_tags_forbidden(self, app): - """Test tag binding without edit permissions.""" - # Arrange - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = False - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - from controllers.service_api.dataset.dataset import DatasetTagBindingApi - - try: - # Act & Assert - with app.test_request_context("/", method="POST"): - api = DatasetTagBindingApi() - with pytest.raises(Forbidden): - api.post("tenant_123") - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagUnbindingApi: - """Test suite for DatasetTagUnbindingApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_unbind_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful unbinding of tag from dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.delete_tag_binding.return_value = None - payload = {"tag_id": "tag_1", "target_id": "dataset_123"} - mock_service_api_ns.payload = payload - - from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi - - try: - # Act - with app.test_request_context("/", method="POST", json=payload): - api = DatasetTagUnbindingApi() - response = api.post("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.delete_tag_binding.assert_called_once_with( - {"tag_id": "tag_1", "target_id": "dataset_123", "type": "knowledge"} - ) - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagsBindingStatusApi: - """Test suite for DatasetTagsBindingStatusApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - def test_get_dataset_tags_binding_status(self, mock_tag_service, app): - """Test retrieval of tags bound to a specific dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.current_tenant_id = "tenant_123" - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Test Tag" - mock_tag_service.get_tags_by_target_id.return_value = [mock_tag] - - from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi - - try: - # Act - with app.test_request_context("/", method="GET"): - api = DatasetTagsBindingStatusApi() - response, status_code = api.get("tenant_123", dataset_id="dataset_123") - - # Assert - assert status_code == 200 - assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}] - assert response["total"] == 1 - mock_tag_service.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123") - finally: - dataset_module.current_user = original_current_user - - -class TestDocumentStatusApi: - """Test suite for DocumentStatusApi batch operations.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_enable_documents(self, mock_doc_service, mock_dataset_service, app): - """Test batch enabling documents.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - mock_doc_service.batch_update_document_status.return_value = None - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1", "doc_2"]}): - api = DocumentStatusApi() - response, status_code = api.patch("tenant_123", "dataset_123", "enable") - - # Assert - assert status_code == 200 - assert response == {"result": "success"} - mock_doc_service.batch_update_document_status.assert_called_once() - - @patch("controllers.service_api.dataset.dataset.DatasetService") - def test_batch_update_dataset_not_found(self, mock_dataset_service, app): - """Test batch update when dataset not found.""" - # Arrange - mock_dataset_service.get_dataset.return_value = None - - from werkzeug.exceptions import NotFound - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(NotFound) as exc_info: - api.patch("tenant_123", "dataset_123", "enable") - assert "Dataset not found" in str(exc_info.value) - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_update_permission_error(self, mock_doc_service, mock_dataset_service, app): - """Test batch update with permission error.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - from services.errors.account import NoPermissionError - - mock_dataset_service.check_dataset_permission.side_effect = NoPermissionError("No permission") - - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(Forbidden): - api.patch("tenant_123", "dataset_123", "enable") - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_update_invalid_action(self, mock_doc_service, mock_dataset_service, app): - """Test batch update with invalid action error.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - mock_doc_service.batch_update_document_status.side_effect = ValueError("Invalid action") - - from controllers.service_api.dataset.dataset import DocumentStatusApi - from controllers.service_api.dataset.error import InvalidActionError - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(InvalidActionError): - api.patch("tenant_123", "dataset_123", "invalid_action") - - """Test DatasetPermissionEnum values.""" - - def test_only_me_permission(self): - """Test ONLY_ME permission value.""" - assert DatasetPermissionEnum.ONLY_ME is not None - - def test_all_team_permission(self): - """Test ALL_TEAM permission value.""" - assert DatasetPermissionEnum.ALL_TEAM is not None - - def test_partial_team_permission(self): - """Test PARTIAL_TEAM permission value.""" - assert DatasetPermissionEnum.PARTIAL_TEAM is not None - - -class TestDatasetErrors: - """Test dataset-related error types.""" - - def test_dataset_in_use_error_can_be_raised(self): - """Test DatasetInUseError can be raised.""" - error = DatasetInUseError() - assert error is not None - - def test_dataset_name_duplicate_error_can_be_raised(self): - """Test DatasetNameDuplicateError can be raised.""" - error = DatasetNameDuplicateError() - assert error is not None - - def test_invalid_action_error_can_be_raised(self): - """Test InvalidActionError can be raised.""" - error = InvalidActionError("Invalid action") - assert error is not None - - -class TestDatasetService: - """Test DatasetService interface methods.""" - - def test_get_datasets_method_exists(self): - """Test DatasetService.get_datasets exists.""" - assert hasattr(DatasetService, "get_datasets") - - def test_get_dataset_method_exists(self): - """Test DatasetService.get_dataset exists.""" - assert hasattr(DatasetService, "get_dataset") - - def test_create_empty_dataset_method_exists(self): - """Test DatasetService.create_empty_dataset exists.""" - assert hasattr(DatasetService, "create_empty_dataset") - - def test_update_dataset_method_exists(self): - """Test DatasetService.update_dataset exists.""" - assert hasattr(DatasetService, "update_dataset") - - def test_delete_dataset_method_exists(self): - """Test DatasetService.delete_dataset exists.""" - assert hasattr(DatasetService, "delete_dataset") - - def test_check_dataset_permission_method_exists(self): - """Test DatasetService.check_dataset_permission exists.""" - assert hasattr(DatasetService, "check_dataset_permission") - - def test_check_dataset_model_setting_method_exists(self): - """Test DatasetService.check_dataset_model_setting exists.""" - assert hasattr(DatasetService, "check_dataset_model_setting") - - def test_check_embedding_model_setting_method_exists(self): - """Test DatasetService.check_embedding_model_setting exists.""" - assert hasattr(DatasetService, "check_embedding_model_setting") - - @patch.object(DatasetService, "get_datasets") - def test_get_datasets_returns_tuple(self, mock_get): - """Test get_datasets returns tuple of datasets and total.""" - mock_datasets = [Mock(), Mock()] - mock_get.return_value = (mock_datasets, 2) - - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant_123", user=Mock()) - assert len(datasets) == 2 - assert total == 2 - - @patch.object(DatasetService, "get_dataset") - def test_get_dataset_returns_dataset(self, mock_get): - """Test get_dataset returns dataset object.""" - mock_dataset = Mock() - mock_dataset.id = str(uuid.uuid4()) - mock_dataset.name = "Test Dataset" - mock_get.return_value = mock_dataset - - result = DatasetService.get_dataset("dataset_id") - assert result.name == "Test Dataset" - - @patch.object(DatasetService, "get_dataset") - def test_get_dataset_returns_none_when_not_found(self, mock_get): - """Test get_dataset returns None when not found.""" - mock_get.return_value = None - - result = DatasetService.get_dataset("nonexistent_id") - assert result is None - - -class TestDatasetPermissionService: - """Test DatasetPermissionService interface.""" - - def test_check_permission_method_exists(self): - """Test DatasetPermissionService.check_permission exists.""" - assert hasattr(DatasetPermissionService, "check_permission") - - def test_get_dataset_partial_member_list_method_exists(self): - """Test DatasetPermissionService.get_dataset_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "get_dataset_partial_member_list") - - def test_update_partial_member_list_method_exists(self): - """Test DatasetPermissionService.update_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "update_partial_member_list") - - def test_clear_partial_member_list_method_exists(self): - """Test DatasetPermissionService.clear_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "clear_partial_member_list") - - -class TestDocumentService: - """Test DocumentService interface.""" - - def test_batch_update_document_status_method_exists(self): - """Test DocumentService.batch_update_document_status exists.""" - assert hasattr(DocumentService, "batch_update_document_status") - - -class TestTagService: - """Test TagService interface.""" - - def test_get_tags_method_exists(self): - """Test TagService.get_tags exists.""" - assert hasattr(TagService, "get_tags") - - def test_save_tags_method_exists(self): - """Test TagService.save_tags exists.""" - assert hasattr(TagService, "save_tags") - - def test_update_tags_method_exists(self): - """Test TagService.update_tags exists.""" - assert hasattr(TagService, "update_tags") - - def test_delete_tag_method_exists(self): - """Test TagService.delete_tag exists.""" - assert hasattr(TagService, "delete_tag") - - def test_save_tag_binding_method_exists(self): - """Test TagService.save_tag_binding exists.""" - assert hasattr(TagService, "save_tag_binding") - - def test_delete_tag_binding_method_exists(self): - """Test TagService.delete_tag_binding exists.""" - assert hasattr(TagService, "delete_tag_binding") - - def test_get_tags_by_target_id_method_exists(self): - """Test TagService.get_tags_by_target_id exists.""" - assert hasattr(TagService, "get_tags_by_target_id") - - def test_get_tag_binding_count_method_exists(self): - """Test TagService.get_tag_binding_count exists.""" - assert hasattr(TagService, "get_tag_binding_count") - - @patch.object(TagService, "get_tags") - def test_get_tags_returns_list(self, mock_get): - """Test get_tags returns list of tags.""" - mock_tags = [ - Mock(id="tag1", name="Tag One", type="knowledge"), - Mock(id="tag2", name="Tag Two", type="knowledge"), - ] - mock_get.return_value = mock_tags - - result = TagService.get_tags("knowledge", "tenant_123") - assert len(result) == 2 - - @patch.object(TagService, "save_tags") - def test_save_tags_returns_tag(self, mock_save): - """Test save_tags returns created tag.""" - mock_tag = Mock() - mock_tag.id = str(uuid.uuid4()) - mock_tag.name = "New Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_save.return_value = mock_tag - - result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) - assert result.name == "New Tag" - - -class TestDocumentStatusAction: - """Test document status action values.""" - - def test_enable_action(self): - """Test enable action.""" - action = "enable" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_disable_action(self): - """Test disable action.""" - action = "disable" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_archive_action(self): - """Test archive action.""" - action = "archive" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_un_archive_action(self): - """Test un_archive action.""" - action = "un_archive" - assert action in ["enable", "disable", "archive", "un_archive"] - - -# ============================================================================= -# API Endpoint Tests -# -# ``DatasetListApi`` and ``DatasetApi`` inherit from ``DatasetApiResource`` -# whose ``method_decorators`` include ``validate_dataset_token``. -# -# Decorator strategy: -# - ``@cloud_edition_billing_rate_limit_check`` preserves ``__wrapped__`` -# → call via ``_unwrap(method)(self, …)``. -# - Methods without billing decorators → call directly; only patch ``db``, -# services, ``current_user``, and ``marshal``. -# ============================================================================= +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def _unwrap(method): @@ -920,6 +235,15 @@ def _unwrap(method): return fn +@pytest.fixture +def app(flask_app_with_containers): + # Uses the full containerised app so that Flask config, extensions, and + # blueprint registrations match production. Most tests mock the service + # layer to isolate controller logic; a few (e.g. test_list_tags_from_db) + # exercise the real DB-backed path to validate end-to-end behaviour. + return flask_app_with_containers + + @pytest.fixture def mock_tenant(): tenant = Mock() @@ -938,12 +262,13 @@ def mock_dataset(): return dataset -class TestDatasetListApiGet: - """Test suite for DatasetListApi.get() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DatasetListApi +# --------------------------------------------------------------------------- - ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. - """ + +class TestDatasetListApiGet: + """Test suite for DatasetListApi.get() endpoint.""" @patch("controllers.service_api.dataset.dataset.marshal") @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @@ -958,7 +283,6 @@ class TestDatasetListApiGet: app, mock_tenant, ): - """Test successful dataset list retrieval.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -981,10 +305,7 @@ class TestDatasetListApiGet: class TestDatasetListApiPost: - """Test suite for DatasetListApi.post() endpoint. - - ``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``. - """ + """Test suite for DatasetListApi.post() endpoint.""" @patch("controllers.service_api.dataset.dataset.marshal") @patch("controllers.service_api.dataset.dataset.current_user") @@ -997,7 +318,6 @@ class TestDatasetListApiPost: app, mock_tenant, ): - """Test successful dataset creation.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -1024,7 +344,6 @@ class TestDatasetListApiPost: app, mock_tenant, ): - """Test DatasetNameDuplicateError when name already exists.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -1040,12 +359,13 @@ class TestDatasetListApiPost: _unwrap(api.post)(api, tenant_id=mock_tenant.id) -class TestDatasetApiGet: - """Test suite for DatasetApi.get() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DatasetApi +# --------------------------------------------------------------------------- - ``get`` has no billing decorators but calls ``DatasetService``, - ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. - """ + +class TestDatasetApiGet: + """Test suite for DatasetApi.get() endpoint.""" @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") @@ -1062,7 +382,6 @@ class TestDatasetApiGet: app, mock_dataset, ): - """Test successful dataset retrieval.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = mock_dataset @@ -1092,7 +411,6 @@ class TestDatasetApiGet: @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_not_found(self, mock_dataset_svc, app, mock_dataset): - """Test 404 when dataset not found.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = None @@ -1114,7 +432,6 @@ class TestDatasetApiGet: app, mock_dataset, ): - """Test 403 when user has no permission.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = mock_dataset @@ -1130,10 +447,7 @@ class TestDatasetApiGet: class TestDatasetApiDelete: - """Test suite for DatasetApi.delete() endpoint. - - ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check``. - """ + """Test suite for DatasetApi.delete() endpoint.""" @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1146,7 +460,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test successful dataset deletion.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.return_value = True @@ -1169,7 +482,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test 404 when dataset not found for deletion.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.return_value = False @@ -1191,7 +503,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test DatasetInUseError when dataset is in use.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.side_effect = services.errors.dataset.DatasetInUseError() @@ -1205,12 +516,13 @@ class TestDatasetApiDelete: _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) -class TestDocumentStatusApiPatch: - """Test suite for DocumentStatusApi.patch() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DocumentStatusApi +# --------------------------------------------------------------------------- - ``patch`` has no billing decorators but calls ``DatasetService``, - ``DocumentService``, and ``current_user``. - """ + +class TestDocumentStatusApiPatch: + """Test suite for DocumentStatusApi.patch() endpoint.""" @patch("controllers.service_api.dataset.dataset.DocumentService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1224,7 +536,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test successful batch document status update.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1256,7 +567,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test 404 when dataset not found.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_dataset_svc.get_dataset.return_value = None @@ -1274,6 +584,39 @@ class TestDocumentStatusApiPatch: action="enable", ) + @patch("controllers.service_api.dataset.dataset.DocumentService") + @patch("controllers.service_api.dataset.dataset.current_user") + @patch("controllers.service_api.dataset.dataset.DatasetService") + def test_batch_update_status_permission_error( + self, + mock_dataset_svc, + mock_current_user, + mock_doc_svc, + app, + mock_tenant, + mock_dataset, + ): + from controllers.service_api.dataset.dataset import DocumentStatusApi + + mock_current_user.__class__ = Account + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError( + "No permission" + ) + + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/status/enable", + method="PATCH", + json={"document_ids": ["doc-1"]}, + ): + api = DocumentStatusApi() + with pytest.raises(Forbidden): + api.patch( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + action="enable", + ) + @patch("controllers.service_api.dataset.dataset.DocumentService") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") @@ -1286,7 +629,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test InvalidActionError when document is indexing.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1320,7 +662,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test InvalidActionError when ValueError raised.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1343,6 +684,11 @@ class TestDocumentStatusApiPatch: ) +# --------------------------------------------------------------------------- +# API endpoint tests — Tags +# --------------------------------------------------------------------------- + + class TestDatasetTagsApiGet: """Test suite for DatasetTagsApi.get() endpoint.""" @@ -1354,7 +700,6 @@ class TestDatasetTagsApiGet: mock_tag_svc, app, ): - """Test successful tag list retrieval.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1368,15 +713,49 @@ class TestDatasetTagsApiGet: assert status == 200 assert len(response) == 1 + mock_tag_svc.get_tags.assert_called_once_with("knowledge", "tenant-1") + + @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but DB COUNT() returns int") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_list_tags_from_db( + self, + mock_current_user, + app, + db_session_with_containers: Session, + ): + """Integration test: creates real Tag rows and retrieves them + through the controller without mocking TagService.""" + from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + ) + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + + tag = Tag( + name="Integration Tag", + type=TagType.KNOWLEDGE, + created_by=account.id, + tenant_id=tenant.id, + ) + db_session_with_containers.add(tag) + db_session_with_containers.commit() + + mock_current_user.__class__ = Account + mock_current_user.current_tenant_id = tenant.id + + from controllers.service_api.dataset.dataset import DatasetTagsApi + + with app.test_request_context("/datasets/tags", method="GET"): + api = DatasetTagsApi() + response, status = api.get(_=None) + + assert status == 200 + assert any(t["name"] == "Integration Tag" for t in response) class TestDatasetTagsApiPost: """Test suite for DatasetTagsApi.post() endpoint.""" - # BUG: dataset.py L512 passes ``binding_count=0`` (int) to - # ``DataSetTag.model_validate()``, but ``DataSetTag.binding_count`` - # is typed ``str | None`` (see fields/tag_fields.py L20). - # This causes a Pydantic ValidationError at runtime. @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but dataset.py passes int 0") @patch("controllers.service_api.dataset.dataset.TagService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1386,7 +765,6 @@ class TestDatasetTagsApiPost: mock_tag_svc, app, ): - """Test successful tag creation.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1409,7 +787,6 @@ class TestDatasetTagsApiPost: @patch("controllers.service_api.dataset.dataset.current_user") def test_create_tag_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1426,6 +803,146 @@ class TestDatasetTagsApiPost: api.post(_=None) +class TestDatasetTagsApiPatch: + """Test suite for DatasetTagsApi.patch() endpoint.""" + + @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but dataset.py passes int 0") + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.service_api_ns") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_update_tag_success( + self, + mock_current_user, + mock_service_api_ns, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = True + mock_current_user.is_dataset_editor = True + + mock_tag = SimpleNamespace(id="tag-1", name="Updated Tag", type="knowledge") + mock_tag_svc.update_tags.return_value = mock_tag + mock_tag_svc.get_tag_binding_count.return_value = 5 + mock_service_api_ns.payload = {"name": "Updated Tag", "tag_id": "tag-1"} + + with app.test_request_context( + "/datasets/tags", + method="PATCH", + json={"name": "Updated Tag", "tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + response, status = api.patch(_=None) + + assert status == 200 + assert response["name"] == "Updated Tag" + mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1") + + @patch("controllers.service_api.dataset.dataset.current_user") + def test_update_tag_forbidden(self, mock_current_user, app): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = False + mock_current_user.is_dataset_editor = False + + with app.test_request_context( + "/datasets/tags", + method="PATCH", + json={"name": "Updated Tag", "tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + with pytest.raises(Forbidden): + api.patch(_=None) + + +class TestDatasetTagsApiDelete: + """Test suite for DatasetTagsApi.delete() endpoint.""" + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.service_api_ns") + @patch("libs.login.current_user") + def test_delete_tag_success( + self, + mock_current_user, + mock_service_api_ns, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + user_obj = Mock(spec=Account) + user_obj.has_edit_permission = True + mock_current_user.has_edit_permission = True + # Assign as plain lambda to avoid AsyncMock returning a coroutine + mock_current_user._get_current_object = lambda: user_obj + + mock_tag_svc.delete_tag.return_value = None + mock_service_api_ns.payload = {"tag_id": "tag-1"} + + with app.test_request_context( + "/datasets/tags", + method="DELETE", + json={"tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + result = api.delete(_=None) + + assert result == ("", 204) + mock_tag_svc.delete_tag.assert_called_once_with("tag-1") + + @patch("libs.login.current_user") + def test_delete_tag_forbidden(self, mock_current_user, app): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + user_obj = Mock(spec=Account) + user_obj.has_edit_permission = False + mock_current_user.has_edit_permission = False + # Assign as plain lambda to avoid AsyncMock returning a coroutine + mock_current_user._get_current_object = lambda: user_obj + + with app.test_request_context( + "/datasets/tags", + method="DELETE", + json={"tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + with pytest.raises(Forbidden): + api.delete(_=None) + + +class TestDatasetTagsBindingStatusApi: + """Test suite for DatasetTagsBindingStatusApi endpoints.""" + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_get_dataset_tags_binding_status( + self, + mock_current_user, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi + + mock_current_user.__class__ = Account + mock_current_user.current_tenant_id = "tenant_123" + mock_tag = Mock() + mock_tag.id = "tag_1" + mock_tag.name = "Test Tag" + mock_tag_svc.get_tags_by_target_id.return_value = [mock_tag] + + with app.test_request_context("/", method="GET"): + api = DatasetTagsBindingStatusApi() + response, status_code = api.get("tenant_123", dataset_id="dataset_123") + + assert status_code == 200 + assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}] + assert response["total"] == 1 + mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123") + + class TestDatasetTagBindingApiPost: """Test suite for DatasetTagBindingApi.post() endpoint.""" @@ -1437,7 +954,6 @@ class TestDatasetTagBindingApiPost: mock_tag_svc, app, ): - """Test successful tag binding.""" from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1454,10 +970,14 @@ class TestDatasetTagBindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingCreatePayload + + mock_tag_svc.save_tag_binding.assert_called_once_with( + TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") + ) @patch("controllers.service_api.dataset.dataset.current_user") def test_bind_tags_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1485,7 +1005,6 @@ class TestDatasetTagUnbindingApiPost: mock_tag_svc, app, ): - """Test successful tag unbinding.""" from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account @@ -1502,10 +1021,14 @@ class TestDatasetTagUnbindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingDeletePayload + + mock_tag_svc.delete_tag_binding.assert_called_once_with( + TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge") + ) @patch("controllers.service_api.dataset.dataset.current_user") def test_unbind_tag_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py new file mode 100644 index 00000000000..4e884626a77 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py @@ -0,0 +1,110 @@ +""" +Testcontainers integration tests for Service API Site controller. +""" + +from __future__ import annotations + +import pytest +from flask import Flask +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.service_api.app.site import AppSiteApi +from models.account import Tenant, TenantStatus +from models.model import App, AppMode, Site + + +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _unwrap(method): + fn = method + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + return fn + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="service-api-site-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="service-api-site-app", + enable_site=True, + enable_api=True, + status="normal", + ) + db_session.add(app_model) + db_session.commit() + return app_model + + +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, + title="Service API Site", + icon_type="emoji", + icon="robot", + icon_background="#ffffff", + description="Service API test site", + default_language="en-US", + prompt_public=True, + show_workflow_steps=True, + customize_token_strategy="not_allow", + use_icon_as_answer_icon=False, + chat_color_theme="light", + chat_color_theme_inverted=False, + ) + db_session.add(site) + db_session.commit() + return site + + +class TestAppSiteApi: + def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + response = _unwrap(api.get)(api, app_model=app_model) + + assert response["title"] == "Service API Site" + assert response["icon"] == "robot" + assert response["description"] == "Service API test site" + + def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) + + def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + archived_tenant = db_session_with_containers.get(Tenant, tenant.id) + assert archived_tenant is not None + archived_tenant.status = TenantStatus.ARCHIVE + db_session_with_containers.commit() + + app_model = db_session_with_containers.get(App, app_model.id) + assert app_model is not None + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py similarity index 51% rename from api/tests/unit_tests/controllers/web/test_site.py rename to api/tests/test_containers_integration_tests/controllers/web/test_site.py index 6e9d754c436..9adb26ff3d2 100644 --- a/api/tests/unit_tests/controllers/web/test_site.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py @@ -1,28 +1,48 @@ -"""Unit tests for controllers.web.site endpoints.""" +"""Testcontainers integration tests for controllers.web.site endpoints.""" from __future__ import annotations from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from controllers.web.site import AppSiteApi, AppSiteInfo +from models import Tenant, TenantStatus +from models.model import App, AppMode, CustomizeTokenStrategy, Site -def _tenant(*, status: str = "normal") -> SimpleNamespace: - return SimpleNamespace( - id="tenant-1", - status=status, - plan="basic", - custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="test-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="test-app", + enable_site=enable_site, + enable_api=True, ) + db_session.add(app_model) + db_session.commit() + return app_model -def _site() -> SimpleNamespace: - return SimpleNamespace( +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, title="Site", icon_type="emoji", icon="robot", @@ -31,77 +51,64 @@ def _site() -> SimpleNamespace: default_language="en", chat_color_theme="light", chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, + code=f"code-{app_id[-6:]}", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, ) + db_session.add(site) + db_session.commit() + return site -# --------------------------------------------------------------------------- -# AppSiteApi -# --------------------------------------------------------------------------- class TestAppSiteApi: @patch("controllers.web.site.FeatureService.get_features") - @patch("controllers.web.site.db") - def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_features.return_value = SimpleNamespace(can_replace_logo=False) - site_obj = _site() - mock_db.session.scalar.return_value = site_obj - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): result = AppSiteApi().get(app_model, end_user) - # marshal_with serializes AppSiteInfo to a dict - assert result["app_id"] == "app-1" + assert result["app_id"] == app_model.id assert result["plan"] == "basic" assert result["enable_site"] is True - @patch("controllers.web.site.db") - def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_db.session.scalar.return_value = None - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) end_user = SimpleNamespace(id="eu-1") with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) - @patch("controllers.web.site.db") - def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + @patch("controllers.web.site.FeatureService.get_features") + def test_archived_tenant_raises_forbidden( + self, mock_features, app: Flask, db_session_with_containers: Session + ) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - from models.account import TenantStatus - - mock_db.session.scalar.return_value = _site() - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.ARCHIVE, - plan="basic", - custom_config_dict={}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) -# --------------------------------------------------------------------------- -# AppSiteInfo -# --------------------------------------------------------------------------- class TestAppSiteInfo: def test_basic_fields(self) -> None: - tenant = _tenant() - site_obj = _site() + tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={}) + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) assert info.app_id == "app-1" @@ -118,7 +125,7 @@ class TestAppSiteInfo: plan="pro", custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, ) - site_obj = _site() + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) assert info.can_replace_logo is True diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103b..f14b2c0ae56 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 2b4c1b59abf..c342e8994b3 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,13 +22,6 @@ import uuid from time import time import pytest -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -40,6 +33,13 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -88,11 +88,11 @@ class TestPauseStatePersistenceLayerTestContainers: def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account - from models.account import Tenant, TenantAccountJoin, TenantAccountRole + from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestPauseStatePersistenceLayerTestContainers: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers: self.session.refresh(self.test_workflow_run) assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING - pause_states = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) - .all() - ) + pause_states = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) + ).all() assert len(pause_states) == 0 def test_layer_requires_initialization(self, db_session_with_containers): diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 13caad799eb..6524d6ce61f 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,12 +4,11 @@ from __future__ import annotations from uuid import uuid4 -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -18,7 +17,15 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction +from models.account import ( + Account, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) from models.human_input import ( EmailExternalRecipientPayload, EmailMemberRecipientPayload, @@ -29,7 +36,7 @@ from models.human_input import ( def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) session.add(tenant) session.flush() @@ -39,7 +46,7 @@ def _create_tenant_with_members(session: Session, member_emails: list[str]) -> t email=email, name=f"Member {index}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) session.add(account) session.flush() diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 0a9b476afc5..5aed230cd4a 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,6 +4,17 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowType from graphon.graph import Graph from graphon.graph_engine import GraphEngine @@ -16,20 +27,9 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, IconType from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun @@ -101,8 +101,8 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, + node_id="start", + config=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -116,8 +116,8 @@ def _build_graph( ], ) human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, + node_id="human", + config=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -130,8 +130,8 @@ def _build_graph( desc=None, ) end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, + node_id="end", + config=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -175,7 +175,7 @@ class TestHumanInputResumeNodeExecutionIntegration: def setup_test_data(self, db_session_with_containers: Session): tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -184,7 +184,7 @@ class TestHumanInputResumeNodeExecutionIntegration: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index cc72dc1cf39..35e41035df6 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase): file_related_id = related_id return File( - id=str(uuid4()), # Generate new UUID for File.id + file_id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, remote_url=remote_url, diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index b745aed1417..2fd289dfbca 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -6,7 +6,6 @@ from decimal import Decimal from uuid import uuid4 from graphon.nodes.human_input.entities import FormDefinition, UserAction - from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py index 078dc0e8de1..6fd6716cbb9 100644 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -1,79 +1,202 @@ -# import secrets +""" +Integration tests for Account and Tenant model methods that interact with the database. -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError +Migrated from unit_tests/models/test_account_models.py, replacing +@patch("models.account.db") mock patches with real PostgreSQL operations. -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin +Covers: +- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin) +- Account.set_tenant_id (resolves tenant + role from real join row) +- Account.get_by_openid (AccountIntegrate lookup then Account fetch) +- Tenant.get_accounts (returns accounts linked via TenantAccountJoin) +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session +def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None: + """Delete rows tracked during the test so committed state does not leak into the DB. + + Rolls back any pending (uncommitted) session state first, then issues DELETE + statements by primary key for each tracked entity (in reverse creation order) + and commits. This cleans up rows created via either flush() or commit(). + """ + db_session.rollback() + for entity in reversed(tracked): + db_session.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session.commit() -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account +def _build_tenant() -> Tenant: + return Tenant(name=f"Tenant {uuid4()}") -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant +def _build_account(email_prefix: str = "account") -> Account: + return Account( + name=f"Account {uuid4()}", + email=f"{email_prefix}_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() +class _DBTrackingTestBase: + """Base class providing a tracker list and shared row factories for account/tenant tests.""" + + _tracked: list + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + self._tracked = [] + yield + _cleanup_tracked_rows(db_session_with_containers, self._tracked) + + def _create_tenant(self, db_session: Session) -> Tenant: + tenant = _build_tenant() + db_session.add(tenant) + db_session.flush() + self._tracked.append(tenant) + return tenant + + def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account: + account = _build_account(email_prefix) + db_session.add(account) + db_session.flush() + self._tracked.append(account) + return account + + def _create_join( + self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True + ) -> TenantAccountJoin: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current) + db_session.add(join) + db_session.flush() + self._tracked.append(join) + return join -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() +class TestAccountCurrentTenantSetter(_DBTrackingTestBase): + """Integration tests for Account.current_tenant property setter.""" -# # Ensure the tenant used in assignment is detached. -# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name + def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None: + """current_tenant getter returns the in-memory _current_tenant without DB access.""" + account = self._create_account(db_session_with_containers) + tenant = self._create_tenant(db_session_with_containers) + account._current_tenant = tenant -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + assert account.current_tenant is tenant -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) + def test_current_tenant_setter_sets_tenant_and_role_when_join_exists( + self, db_session_with_containers: Session + ) -> None: + """Setting current_tenant loads the join row and assigns role when relationship exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER) + db_session_with_containers.commit() -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + account.current_tenant = tenant + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.OWNER + + def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None: + """Setting current_tenant results in _current_tenant=None when no join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.current_tenant = tenant + + assert account._current_tenant is None + + +class TestAccountSetTenantId(_DBTrackingTestBase): + """Integration tests for Account.set_tenant_id method.""" + + def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id loads the tenant and assigns role when a join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.ADMIN + + def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id does nothing when no join row matches the tenant.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is None + + +class TestAccountGetByOpenId(_DBTrackingTestBase): + """Integration tests for Account.get_by_openid class method.""" + + def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns the Account when a matching AccountIntegrate row exists.""" + account = self._create_account(db_session_with_containers, email_prefix="openid") + provider = "google" + open_id = f"google_{uuid4()}" + + integrate = AccountIntegrate( + account_id=account.id, + provider=provider, + open_id=open_id, + encrypted_token="token", + ) + db_session_with_containers.add(integrate) + db_session_with_containers.flush() + self._tracked.append(integrate) + + result = Account.get_by_openid(provider, open_id) + + assert result is not None + assert result.id == account.id + + def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns None when no AccountIntegrate row matches.""" + result = Account.get_by_openid("github", f"github_{uuid4()}") + + assert result is None + + +class TestTenantGetAccounts(_DBTrackingTestBase): + """Integration tests for Tenant.get_accounts method.""" + + def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None: + """get_accounts returns all accounts linked to the tenant via TenantAccountJoin.""" + tenant = self._create_tenant(db_session_with_containers) + account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False) + self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False) + + accounts = tenant.get_accounts() + + assert len(accounts) == 2 + account_ids = {a.id for a in accounts} + assert account1.id in account_ids + assert account2.id in account_ids diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py new file mode 100644 index 00000000000..f10f519e250 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py @@ -0,0 +1,149 @@ +""" +Integration tests for Conversation.inputs and Message.inputs tenant resolution. + +Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching +with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB. +""" + +from collections.abc import Generator +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.workflow.file_reference import build_file_reference +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod +from models.model import App, AppMode, Conversation, Message + + +def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict: + mapping: dict = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "reference": build_file_reference(record_id=record_id), + "type": "document", + "filename": "example.txt", + "extension": ".txt", + "mime_type": "text/plain", + "size": 1, + } + if tenant_id is not None: + mapping["tenant_id"] = tenant_id + return mapping + + +class TestConversationMessageInputsTenantResolution: + """Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session) -> App: + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session.add(app) + db_session.flush() + return app + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_via_db_for_local_file( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id from real App row when file mapping has no tenant_id.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1")} + + restored_inputs = owner.inputs + + # The tenant_id should come from the real App row in the DB + assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == app.tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_uses_serialized_tenant_id_skipping_db_lookup( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs uses tenant_id from the file mapping payload without hitting the DB.""" + app = self._create_app(db_session_with_containers) + payload_tenant_id = "tenant-from-payload" + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == payload_tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_for_file_list( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id for a list of file mappings.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = { + "files": [ + _build_local_file_mapping("upload-1"), + _build_local_file_mapping("upload-2"), + ] + } + + restored_inputs = owner.inputs + + assert len(build_calls) == 2 + assert all(call[1] == app.tenant_id for call in build_calls) + assert restored_inputs["files"] == [ + {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}, + {"tenant_id": app.tenant_id, "upload_file_id": "upload-2"}, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py new file mode 100644 index 00000000000..6352f815df7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py @@ -0,0 +1,314 @@ +""" +Integration tests for Conversation.status_count and Site.generate_code model properties. + +Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and +test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from graphon.enums import WorkflowExecutionStatus +from models.enums import ConversationFromSource, InvokeFrom +from models.model import App, AppMode, Conversation, Message, Site +from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType + + +class TestConversationStatusCount: + """Integration tests for Conversation.status_count property.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.ADVANCED_CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _create_conversation(self, db_session: Session, app: App) -> Conversation: + conversation = Conversation( + app_id=app.id, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP, + from_source=ConversationFromSource.API, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + return conversation + + def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow: + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.CHAT, + version="draft", + graph="{}", + created_by=created_by, + ) + workflow._features = "{}" + db_session.add(workflow) + db_session.flush() + return workflow + + def _create_workflow_run( + self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str + ) -> WorkflowRun: + run = WorkflowRun( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type=WorkflowType.CHAT, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="draft", + status=status, + created_by_role="account", + created_by=created_by, + ) + db_session.add(run) + db_session.flush() + return run + + def _create_message( + self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + _inputs={}, + query="Test query", + message={"role": "user", "content": "Test query"}, + answer="Test answer", + model_provider="openai", + model_id="gpt-4", + message_tokens=10, + message_unit_price=0, + answer_tokens=10, + answer_unit_price=0, + total_price=0, + currency="USD", + from_source=ConversationFromSource.API, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + return message + + def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None: + """status_count returns None when conversation has no messages with workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + + result = conversation.status_count + + assert result is None + + def test_status_count_returns_none_when_messages_have_no_workflow_run_id( + self, db_session_with_containers: Session + ) -> None: + """status_count returns None when messages exist but none have workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None) + + result = conversation.status_count + + assert result is None + + def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts succeeded workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts failed workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 1 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts paused workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 1 + + def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None: + """status_count counts multiple workflow runs with different statuses.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + + for status in [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + WorkflowExecutionStatus.PAUSED, + ]: + run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 1 + assert result["partial_success"] == 1 + assert result["paused"] == 1 + + def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None: + """status_count excludes workflow runs belonging to a different app.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + other_app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, other_app, created_by) + + # Workflow run belongs to other_app, not app + other_run = self._create_workflow_run( + db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + # Message references that run but is in a conversation under app + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id) + + result = conversation.status_count + + # The run should be excluded because app_id filter doesn't match + assert result is not None + assert result["success"] == 0 + + +class TestSiteGenerateCode: + """Integration tests for Site.generate_code static method.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code string of the requested length.""" + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + + def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code not already in use.""" + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session_with_containers.add(app) + db_session_with_containers.flush() + + site = Site( + app_id=app.id, + title="Test Site", + default_language="en-US", + customize_token_strategy="not_allow", + ) + # Set an explicit code so generate_code must avoid it + site.code = "AAAAAAAA" + db_session_with_containers.add(site) + db_session_with_containers.flush() + + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + assert code != site.code diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py index 9cf96c1ca70..b325c97f7d1 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -5,11 +5,12 @@ from typing import Any, NamedTuple import pytest import sqlalchemy as sa from sqlalchemy import exc as sa_exc -from sqlalchemy import insert +from sqlalchemy import insert, select from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR +from graphon.model_runtime.entities.model_entities import ModelType from models.types import EnumText _USER_TABLE = "enum_text_users" @@ -58,6 +59,13 @@ class _ColumnTest(_Base): long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) +class _LegacyModelTypeRecord(_Base): + __tablename__ = "enum_text_legacy_model_type_test" + + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False) + + def _first[T](it: Iterable[T]) -> T: ls = list(it) if not ls: @@ -129,12 +137,12 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == admin_user_id).first() + user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1)) assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == normal_user_id).first() + user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1)) assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -198,6 +206,26 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine_with_containers) as session: - _user = session.query(_User).where(_User.id == 1).first() + _user = session.scalar(select(_User).where(_User.id == 1).limit(1)) assert str(exc.value) == "'invalid' is not a valid _UserType" + + def test_select_legacy_model_type_values(self, engine_with_containers: Engine): + insertion_sql = """ + INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES + (1, 'text-generation'), + (2, 'embeddings'), + (3, 'reranking'); + """ + with Session(engine_with_containers) as session: + session.execute(sa.text(insertion_sql)) + session.commit() + + with Session(engine_with_containers) as session: + records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all() + + assert [record.model_type for record in records] == [ + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py new file mode 100644 index 00000000000..14c2263110c --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py @@ -0,0 +1,170 @@ +""" +Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user. + +Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing +monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows +persisted in PostgreSQL so the db.session.get() call executes against the DB. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +class TestWorkflowNodeExecutionModelCreatedBy: + """Integration tests for WorkflowNodeExecutionModel creator lookup properties.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_account(self, db_session: Session) -> Account: + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + return account + + def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser: + end_user = EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="service_api", + external_user_id=f"ext-{uuid4()}", + name="End User", + session_id=f"session-{uuid4()}", + ) + end_user.is_anonymous = False + db_session.add(end_user) + db_session.flush() + return end_user + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.WORKFLOW, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _make_execution( + self, tenant_id: str, app_id: str, created_by_role: str, created_by: str + ) -> WorkflowNodeExecutionModel: + return WorkflowNodeExecutionModel( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid4()), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=None, + index=1, + predecessor_node_id=None, + node_execution_id=None, + node_id="n1", + node_type="start", + title="Start", + inputs=None, + process_data=None, + outputs=None, + status="succeeded", + error=None, + elapsed_time=0.0, + execution_metadata=None, + created_by_role=created_by_role, + created_by=created_by, + ) + + def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_account returns the Account row when role is ACCOUNT.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is not None + assert result.id == account.id + + def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None: + """created_by_account returns None when role is END_USER, even if an Account exists.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is None + + def test_created_by_end_user_returns_end_user_when_role_is_end_user( + self, db_session_with_containers: Session + ) -> None: + """created_by_end_user returns the EndUser row when role is END_USER.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is not None + assert result.id == end_user.id + + def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index a68b3a08c7a..641399c7f94 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index d28cfda1598..aebe87839c5 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( @@ -220,7 +220,6 @@ class TestDeleteRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -280,7 +279,6 @@ class TestCountRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -544,7 +542,6 @@ class TestPrivateWorkflowPauseEntity: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -574,7 +571,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -606,7 +602,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -672,7 +667,6 @@ class TestBuildHumanInputRequiredReason: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index 7f44eb6ca37..54b7afc018a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -12,11 +12,11 @@ from decimal import Decimal from uuid import uuid4 import pytest -from graphon.nodes.human_input.entities import FormDefinition, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 00000000000..fa78f1c28b1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,395 @@ +"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository.""" + +from __future__ import annotations + +import json +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.model_runtime.utils.encoders import jsonable_encoder +from models.account import Account, Tenant +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +def _create_account_with_tenant(session: Session) -> Account: + tenant = Tenant(name="Test Workspace") + session.add(tenant) + session.flush() + + account = Account(name="test", email=f"test-{uuid4()}@example.com") + session.add(account) + session.flush() + + account._current_tenant = tenant + return account + + +def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository: + engine = session.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sessionmaker(bind=engine, expire_on_commit=False), + user=account, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def _create_node_execution_model( + session: Session, + *, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_run_id: str, + index: int = 1, + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING, +) -> WorkflowNodeExecutionModel: + model = WorkflowNodeExecutionModel( + id=str(uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run_id, + index=index, + predecessor_node_id=None, + node_execution_id=str(uuid4()), + node_id=f"node-{index}", + node_type=BuiltinNodeTypes.START, + title=f"Test Node {index}", + inputs='{"input_key": "input_value"}', + process_data='{"process_key": "process_value"}', + outputs='{"output_key": "output_value"}', + status=status, + error=None, + elapsed_time=1.5, + execution_metadata="{}", + created_at=datetime.now(), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + finished_at=None, + ) + session.add(model) + session.flush() + return model + + +class TestSave: + def test_save_new_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"result": "success"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.tenant_id == account.current_tenant_id + assert saved.app_id == app_id + assert saved.node_id == "node-1" + assert saved.status == WorkflowNodeExecutionStatus.RUNNING + + def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs=None, + process_data=None, + outputs=None, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=0.0, + metadata=None, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + execution.elapsed_time = 2.5 + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert saved.elapsed_time == 2.5 + + +class TestGetByWorkflowExecution: + def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + db_session_with_containers.commit() + + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repo.get_by_workflow_execution( + workflow_execution_id=workflow_run_id, + order_config=order_config, + ) + + assert len(result) == 2 + assert result[0].index == 2 + assert result[1].index == 1 + assert all(isinstance(r, WorkflowNodeExecution) for r in result) + + def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.PAUSED, + ) + db_session_with_containers.commit() + + result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id) + + assert len(result) == 1 + assert result[0].index == 1 + + +class TestToDbModel: + def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + domain_model = WorkflowNodeExecution( + id="test-id", + workflow_id="test-workflow-id", + node_execution_id="test-node-execution-id", + workflow_execution_id="test-workflow-run-id", + index=1, + predecessor_node_id="test-predecessor-id", + node_id="test-node-id", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"output_key": "output_value"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"), + }, + created_at=datetime.now(), + finished_at=None, + ) + + db_model = repo._to_db_model(domain_model) + + assert isinstance(db_model, WorkflowNodeExecutionModel) + assert db_model.id == domain_model.id + assert db_model.tenant_id == account.current_tenant_id + assert db_model.app_id == app_id + assert db_model.workflow_id == domain_model.workflow_id + assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert db_model.workflow_run_id == domain_model.workflow_execution_id + assert db_model.index == domain_model.index + assert db_model.predecessor_node_id == domain_model.predecessor_node_id + assert db_model.node_execution_id == domain_model.node_execution_id + assert db_model.node_id == domain_model.node_id + assert db_model.node_type == domain_model.node_type + assert db_model.title == domain_model.title + assert db_model.inputs_dict == domain_model.inputs + assert db_model.process_data_dict == domain_model.process_data + assert db_model.outputs_dict == domain_model.outputs + assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata) + assert db_model.status == domain_model.status + assert db_model.error == domain_model.error + assert db_model.elapsed_time == domain_model.elapsed_time + assert db_model.created_at == domain_model.created_at + assert db_model.created_by_role == CreatorUserRole.ACCOUNT + assert db_model.created_by == account.id + assert db_model.finished_at == domain_model.finished_at + + +class TestToDomainModel: + def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + inputs_dict = {"input_key": "input_value"} + process_data_dict = {"process_key": "process_value"} + outputs_dict = {"output_key": "output_value"} + metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100} + now = datetime.now() + + db_model = WorkflowNodeExecutionModel() + db_model.id = "test-id" + db_model.tenant_id = account.current_tenant_id + db_model.app_id = app_id + db_model.workflow_id = "test-workflow-id" + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = "test-workflow-run-id" + db_model.index = 1 + db_model.predecessor_node_id = "test-predecessor-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.node_id = "test-node-id" + db_model.node_type = BuiltinNodeTypes.START + db_model.title = "Test Node" + db_model.inputs = json.dumps(inputs_dict) + db_model.process_data = json.dumps(process_data_dict) + db_model.outputs = json.dumps(outputs_dict) + db_model.status = WorkflowNodeExecutionStatus.RUNNING + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = json.dumps(metadata_dict) + db_model.created_at = now + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert isinstance(domain_model, WorkflowNodeExecution) + assert domain_model.id == "test-id" + assert domain_model.workflow_id == "test-workflow-id" + assert domain_model.workflow_execution_id == "test-workflow-run-id" + assert domain_model.index == 1 + assert domain_model.predecessor_node_id == "test-predecessor-id" + assert domain_model.node_execution_id == "test-node-execution-id" + assert domain_model.node_id == "test-node-id" + assert domain_model.node_type == BuiltinNodeTypes.START + assert domain_model.title == "Test Node" + assert domain_model.inputs == inputs_dict + assert domain_model.process_data == process_data_dict + assert domain_model.outputs == outputs_dict + assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING + assert domain_model.error is None + assert domain_model.elapsed_time == 1.5 + assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()} + assert domain_model.created_at == now + assert domain_model.finished_at is None + + def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + process_data = {"normal": "data"} + db_model = WorkflowNodeExecutionModel() + db_model.id = str(uuid4()) + db_model.tenant_id = account.current_tenant_id + db_model.app_id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_execution_id = str(uuid4()) + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.inputs = None + db_model.process_data = json.dumps(process_data) + db_model.outputs = None + db_model.status = "succeeded" + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now() + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.process_data == process_data + assert domain_model.process_data_truncated is False + assert domain_model.get_truncated_process_data() is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index c5e9201ee37..d6f0657380b 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/integration_tests/vdb/matrixone/__init__.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/matrixone/__init__.py rename to api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py new file mode 100644 index 00000000000..8fc1809a467 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py @@ -0,0 +1,255 @@ +""" +Integration tests for RagPipelineService methods that interact with the database. + +Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing +db.session.scalar/commit/delete mocker patches with real PostgreSQL operations. + +Covers: +- get_pipeline: Dataset and Pipeline lookups +- update_customized_pipeline_template: find + unique-name check + commit +- delete_customized_pipeline_template: find + delete + commit +""" + +from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate +from models.enums import DataSourceType +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class TestRagPipelineServiceGetPipeline: + """Integration tests for RagPipelineService.get_pipeline.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _make_service(self, flask_app_with_containers) -> RagPipelineService: + with ( + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", + return_value=None, + ), + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=None, + ), + ): + session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine) + return RagPipelineService(session_maker=session_factory) + + def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline: + pipeline = Pipeline( + tenant_id=tenant_id, + name=f"Pipeline {uuid4()}", + description="", + created_by=created_by, + ) + db_session.add(pipeline) + db_session.flush() + return pipeline + + def _create_dataset( + self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"Dataset {uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + pipeline_id=pipeline_id, + ) + db_session.add(dataset) + db_session.flush() + return dataset + + def test_get_pipeline_raises_when_dataset_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset does not exist.""" + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="Dataset not found"): + service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4())) + + def test_get_pipeline_raises_when_pipeline_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset exists but has no linked pipeline.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"): + service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + def test_get_pipeline_returns_pipeline_when_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline returns the Pipeline when both Dataset and Pipeline exist.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + pipeline = self._create_pipeline(db_session_with_containers, tenant_id, created_by) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + assert result.id == pipeline.id + + +class TestUpdateCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.update_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template( + self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template" + ) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=name, + description="Original description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """update_customized_pipeline_template updates name and description.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Updated Name", + description="Updated description", + icon_info=IconInfo(icon="🔥"), + ) + result = RagPipelineService.update_customized_pipeline_template(template.id, info) + + assert result.name == "Updated Name" + assert result.description == "Updated description" + + def test_update_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="New Name", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.update_customized_pipeline_template(str(uuid4()), info) + + def test_update_template_raises_on_duplicate_name( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when new name already exists.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original") + self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate") + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Duplicate", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Template name is already exists"): + RagPipelineService.update_customized_pipeline_template(template1.id, info) + + +class TestDeleteCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.delete_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=f"Template {uuid4()}", + description="Description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """delete_customized_pipeline_template removes the template from the DB.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + template_id = template.id + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + RagPipelineService.delete_customized_pipeline_template(template_id) + + # Verify the record is deleted within the same context + from sqlalchemy import select + + from extensions.ext_database import db as ext_db + + remaining = ext_db.session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) + ) + assert remaining is None + + def test_delete_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """delete_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.delete_customized_pipeline_template(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index cc9596d15fb..9a53ff087c5 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -9,7 +9,7 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace -from models import AccountStatus, TenantAccountJoin +from models import AccountStatus, TenantAccountJoin, TenantStatus from services.account_service import AccountService, RegisterService, TenantService, TokenPair from services.errors.account import ( AccountAlreadyInTenantError, @@ -2851,7 +2851,7 @@ class TestRegisterService: interface_language="en-US", password=existing_pending_member_password, ) - existing_account.status = "pending" + existing_account.status = AccountStatus.PENDING db_session_with_containers.commit() @@ -2941,7 +2941,7 @@ class TestRegisterService: interface_language="en-US", password=already_in_tenant_password, ) - existing_account.status = "active" + existing_account.status = AccountStatus.ACTIVE db_session_with_containers.commit() @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "archive" + tenant.status = TenantStatus.ARCHIVE db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4f3c0e42000..00a2f9a59f4 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,7 +842,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType - from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 33955d5d84b..77ce28b999f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -1,56 +1,104 @@ +from __future__ import annotations + +import base64 import json +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest import yaml from faker import Faker -from models.model import App, AppModelConfig +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes +from models import Account, AppMode +from models.model import AppModelConfig, IconType +from services import app_dsl_service from services.account_service import AccountService, TenantService -from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_dsl_service import ( + CHECK_DEPENDENCIES_REDIS_KEY_PREFIX, + CURRENT_DSL_VERSION, + DSL_MAX_SIZE, + IMPORT_INFO_REDIS_EXPIRY, + IMPORT_INFO_REDIS_KEY_PREFIX, + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) from services.app_service import AppService from tests.test_containers_integration_tests.helpers import generate_valid_password +_DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001" +_DEFAULT_ACCOUNT_ID = "00000000-0000-0000-0000-000000000002" + + +def _account_mock(*, tenant_id: str = _DEFAULT_TENANT_ID, account_id: str = _DEFAULT_ACCOUNT_ID) -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def _pending_yaml_content(version: str = "99.0.0") -> bytes: + return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode() + class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, - patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, - patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, - patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, - patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): - # Setup default mock returns mock_workflow_service.return_value.get_draft_workflow.return_value = None mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() mock_dependencies_service.generate_latest_dependencies.return_value = [] mock_dependencies_service.get_leaked_dependencies.return_value = [] mock_dependencies_service.generate_dependencies.return_value = [] - mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None - mock_ssrf_proxy.get.return_value.content = b"test content" - mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None - mock_redis_client.setex.return_value = None - mock_redis_client.get.return_value = None - mock_redis_client.delete.return_value = None mock_app_was_created.send.return_value = None - mock_app_model_config_was_updated.send.return_value = None - # Mock ModelManager for app service mock_model_instance = mock_model_manager.return_value mock_model_instance.get_default_model_instance.return_value = None - mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + mock_model_instance.get_default_provider_model_name.return_value = ( + "openai", + "gpt-3.5-turbo", + ) - # Mock FeatureService and EnterpriseService mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None @@ -58,34 +106,16 @@ class TestAppDslService: yield { "workflow_service": mock_workflow_service, "dependencies_service": mock_dependencies_service, - "draft_variable_service": mock_draft_variable_service, - "ssrf_proxy": mock_ssrf_proxy, - "redis_client": mock_redis_client, "app_was_created": mock_app_was_created, - "app_model_config_was_updated": mock_app_model_config_was_updated, "model_manager": mock_model_manager, "feature_service": mock_feature_service, "enterprise_service": mock_enterprise_service, } def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): - """ - Helper method to create a test app and account for testing. - - Args: - db_session_with_containers: Database session from testcontainers infrastructure - mock_external_service_dependencies: Mock dependencies - - Returns: - tuple: (app, account) - Created app and account instances - """ fake = Faker() - - # Setup mocks for account creation with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True - - # Create account and tenant first account = AccountService.create_account( email=fake.email(), name=fake.name(), @@ -94,8 +124,6 @@ class TestAppDslService: ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - - # Setup app creation arguments app_args = { "name": fake.company(), "description": fake.text(max_nb_chars=100), @@ -106,17 +134,11 @@ class TestAppDslService: "api_rph": 100, "api_rpm": 10, } - - # Create app app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) - return app, account - def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): - """ - Helper method to create simple YAML content for testing. - """ + def _create_simple_yaml_content(self, app_name: str = "Test App", app_mode: str = "chat") -> str: yaml_data = { "version": "0.3.0", "kind": "app", @@ -145,88 +167,739 @@ class TestAppDslService: } return yaml.dump(yaml_data, allow_unicode=True) - def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML content. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + # ── Version Compatibility ───────────────────────────────────────── - # Import app without YAML content - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_CONTENT, - name="Missing Content App", - ) + def test_check_version_compatibility_invalid_version_returns_failed(self): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_content is required" in result.error - assert result.imported_dsl_version == "" + def test_check_version_compatibility_newer_version_returns_pending(self): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING - def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML URL. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + def test_check_version_compatibility_minor_older_returns_completed_with_warnings( + self, + ): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS - # Import app without YAML URL - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_URL, - name="Missing URL App", - ) + def test_check_version_compatibility_equal_returns_completed(self): + assert _check_version_compatibility(CURRENT_DSL_VERSION) == ImportStatus.COMPLETED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_url is required" in result.error - assert result.imported_dsl_version == "" + # ── Import: Validation ──────────────────────────────────────────── - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app - - def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with invalid import mode. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Create YAML content - yaml_content = self._create_simple_yaml_content(fake.company(), "chat") - - # Import app with invalid mode should raise ValueError - dsl_service = AppDslService(db_session_with_containers) - with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): - dsl_service.import_app( - account=account, + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app( + account=_account_mock(), import_mode="invalid-mode", - yaml_content=yaml_content, - name="Invalid Mode App", + yaml_content="version: '0.1.0'", ) - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_import_app_missing_yaml_content(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for chat app. - """ - fake = Faker() + def test_import_app_missing_yaml_url(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="[]", + ) + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + ) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="x: y", + ) + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + # ── Import: YAML URL ────────────────────────────────────────────── + + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ssrf_proxy, + "get", + lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch): + response = MagicMock() + response.content = b"" + response.raise_for_status.return_value = None + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch): + response = MagicMock() + response.content = b"x" * (DSL_MAX_SIZE + 1) + response.raise_for_status.return_value = None + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch): + yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + assert requested_urls == [yaml_url] + + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch): + yaml_url = "https://github.com/acme/repo/blob/main/app.yml" + raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + assert url == raw_url + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert requested_urls == [raw_url] + + # ── Import: App ID checks ──────────────────────────────────────── + + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=str(uuid4()), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + assert app.mode == "chat" + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=app.id, + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + # ── Import: Flow ────────────────────────────────────────────────── + + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{result.id}" + stored = redis_client.get(redis_key) + assert stored is not None + + def test_import_app_completed_uses_declared_dependencies( + self, db_session_with_containers, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + dependencies_payload = [ + { + "type": "package", + "value": { + "plugin_unique_identifier": "langgenius/google", + "version": "1.0.0", + }, + } + ] + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + + @pytest.mark.parametrize("has_workflow", [True, False]) + def test_import_app_legacy_versions_extract_dependencies( + self, db_session_with_containers, monkeypatch, has_workflow: bool + ): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + draft_var_service = MagicMock() + monkeypatch.setattr( + app_dsl_service, + "WorkflowDraftVariableService", + lambda *args, **kwargs: draft_var_service, + ) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump(data), + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id=created_app.id) + + # ── Confirm Import ──────────────────────────────────────────────── + + def test_confirm_import_expired_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, pending.model_dump_json()) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == created_app.id + assert redis_client.get(redis_key) is None + + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "validation error" in result.error + + def test_confirm_import_exception_returns_failed(self, db_session_with_containers): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + + # ── Check Dependencies ──────────────────────────────────────────── + + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + app_model = SimpleNamespace(id=str(uuid4()), tenant_id=_DEFAULT_TENANT_ID) + result = service.check_dependencies(app_model=app_model) + assert result.leaked_dependencies == [] + + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch): + app_id = str(uuid4()) + pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending.model_dump_json(), + ) + + dep = app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(db_session_with_containers) + result = service.check_dependencies(app_model=SimpleNamespace(id=app_id, tenant_id=_DEFAULT_TENANT_ID)) + assert len(result.leaked_dependencies) == 1 + + def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}", + IMPORT_INFO_REDIS_EXPIRY, + mock_dependencies_json, + ) + + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + assert result.leaked_dependencies == [] + + # ── Create/Update App ───────────────────────────────────────────── + + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + app_model_config=None, + ) + service = AppDslService(db_session_with_containers) + updated = service._create_or_update_app( + app=app, + data={ + "app": { + "mode": AppMode.WORKFLOW.value, + "name": "yaml-name", + "icon_type": IconType.IMAGE, + "icon": "X", + }, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( + self, db_session_with_containers, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_wf_svc = mock_external_service_dependencies["workflow_service"] + mock_wf_svc.return_value.get_draft_workflow.return_value = MagicMock(unique_hash="uh") + + service = AppDslService(db_session_with_containers) + deps = [ + app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + ] + data = { + "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "graph": {"nodes": []}, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=account, dependencies=deps) + + assert app.tenant_id == account.current_tenant_id + mock_external_service_dependencies["app_was_created"].send.assert_called_once() + mock_wf_svc.return_value.sync_draft_workflow.assert_called_once() + + stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") + assert stored is not None + + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.WORKFLOW.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_creates_model_config_and_sends_event( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + app.app_model_config_id = None + db_session_with_containers.commit() + + service = AppDslService(db_session_with_containers) + service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.CHAT.value}, + "model_config": {"model": {"provider": "openai"}}, + }, + account=account, + ) + + db_session_with_containers.expire_all() + assert app.app_model_config_id is not None + + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.RAG_PIPELINE.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + # ── Export ───────────────────────────────────────────────────────── + + def test_export_dsl_delegates_by_mode(self, monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: workflow_calls.append(True), + ) + monkeypatch.setattr( + AppDslService, + "_append_model_config_export_data", + lambda *_args, **_kwargs: model_calls.append(True), + ) + + workflow_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = SimpleNamespace( + mode=AppMode.CHAT.value, + tenant_id=_DEFAULT_TENANT_ID, + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: None, + ) + + emoji_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="Emoji App", + icon="🎨", + icon_type=IconType.EMOJI, + icon_background="#FF5733", + description="App with emoji icon", + use_icon_as_answer_icon=True, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(emoji_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "🎨" + assert data["app"]["icon_type"] == "emoji" + assert data["app"]["icon_background"] == "#FF5733" + + image_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="Image App", + icon="https://example.com/icon.png", + icon_type=IconType.IMAGE, + icon_background="#FFEAD5", + description="App with image icon", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(image_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "https://example.com/icon.png" + assert data["app"]["icon_type"] == "image" + assert data["app"]["icon_background"] == "#FFEAD5" + + def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - # Create model config for the app model_config = AppModelConfig( app_id=app.id, provider="openai", @@ -247,53 +920,38 @@ class TestAppDslService: created_by=account.id, updated_by=account.id, ) - model_config.id = fake.uuid4() - - # Set the app_model_config_id to link the config + model_config.id = str(uuid4()) app.app_model_config_id = model_config.id db_session_with_containers.add(model_config) db_session_with_containers.commit() - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == app.mode - assert exported_data["app"]["icon"] == app.icon - assert exported_data["app"]["icon_background"] == app.icon_background - assert exported_data["app"]["description"] == app.description - - # Verify model config was exported assert "model_config" in exported_data - # The exported model_config structure may be different from the database structure - # Check that the model config exists and has the expected content - assert exported_data["model_config"] is not None - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for workflow app. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], @@ -302,54 +960,40 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.return_value = mock_workflow - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - - # Verify workflow service was called - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, None - ) def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export with specific workflow ID. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow when specific workflow_id is provided mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], } - # Mock the get_draft_workflow method to return different workflows based on workflow_id - def mock_get_draft_workflow(app_model, workflow_id=None): - if workflow_id == "specific-workflow-id": + workflow_id = str(uuid4()) + + def mock_get_draft_workflow(app_model, wf_id=None): + if wf_id == workflow_id: return mock_workflow return None @@ -357,78 +1001,351 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.side_effect = mock_get_draft_workflow - # Export DSL with specific workflow ID - exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id="specific-workflow-id") - - # Parse exported YAML + exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id=workflow_id) exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name - assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported - assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - - # Verify workflow service was called with specific workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "specific-workflow-id" - ) def test_export_dsl_with_invalid_workflow_id_raises_error( self, db_session_with_containers, mock_external_service_dependencies ): - """ - Test that export_dsl raises error when invalid workflow ID is provided. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return None when invalid workflow ID is provided mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None - # Export DSL with invalid workflow ID should raise ValueError - with pytest.raises(ValueError, match="Missing draft workflow configuration, please check."): - AppDslService.export_dsl(app, include_secret=False, workflow_id="invalid-workflow-id") + with pytest.raises( + ValueError, + match="Missing draft workflow configuration, please check.", + ): + AppDslService.export_dsl(app, include_secret=False, workflow_id=str(uuid4())) - # Verify workflow service was called with the invalid workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "invalid-workflow-id" + # ── Workflow Export Data ─────────────────────────────────────────── + + def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + { + "data": { + "type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + "dataset_ids": ["d1", "d2"], + } + }, + { + "data": { + "type": BuiltinNodeTypes.TOOL, + "credential_id": "secret", + } + }, + { + "data": { + "type": BuiltinNodeTypes.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + { + "data": { + "type": TRIGGER_SCHEDULE_NODE_TYPE, + "config": {"x": 1}, + } + }, + { + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "webhook_url": "x", + "webhook_debug_url": "y", + } + }, + { + "data": { + "type": TRIGGER_PLUGIN_NODE_TYPE, + "subscription_id": "s", + } + }, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, + "encrypt_dataset_id", + lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}", + ) + monkeypatch.setattr( + app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID), + include_secret=False, + workflow_id=None, ) - def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful dependency checking. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == [ + f"enc:{_DEFAULT_TENANT_ID}:d1", + f"enc:{_DEFAULT_TENANT_ID}:d2", + ] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] - # Mock Redis to return dependencies - mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' - mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) - # Check dependencies - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.check_dependencies(app_model=app) + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID), + include_secret=False, + workflow_id=None, + ) - # Verify result - assert result.leaked_dependencies == [] + # ── Model Config Export Data ────────────────────────────────────── - # Verify Redis was queried - mock_external_service_dependencies["redis_client"].get.assert_called_once_with( - f"app_check_dependencies:{app.id}" + def test_append_model_config_export_data_filters_credential_id(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}}) + app_model = SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID, app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] + + def test_append_model_config_export_data_requires_app_config(self): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None)) + + # ── Dependency Extraction ───────────────────────────────────────── + + def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", ) - # Verify dependencies service was called - mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: SimpleNamespace(provider_id="p1"), + ) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")), + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr( + app_dsl_service.KnowledgeRetrievalNodeData, + "model_validate", + kr_validate, + ) + + graph = { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.TOOL}}, + {"data": {"type": BuiltinNodeTypes.LLM}}, + {"data": {"type": BuiltinNodeTypes.QUESTION_CLASSIFIER}}, + {"data": {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == [ + "tool:p1", + "model:m1", + "model:m2", + "model:m3", + "model:m4", + ] + + def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_workflow_graph( + {"nodes": [{"data": {"type": BuiltinNodeTypes.TOOL}}]} + ) + assert deps == [] + + def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + # ── Leaked Dependencies ─────────────────────────────────────────── + + def test_get_leaked_dependencies_empty_returns_empty(self): + assert AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, []) == [] + + def test_get_leaked_dependencies_delegates(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, [SimpleNamespace(id="x")]) + assert len(res) == 1 + + # ── Encryption/Decryption ───────────────────────────────────────── + + def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch): + tenant_id = _DEFAULT_TENANT_ID + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + False, + ) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + def test_decrypt_dataset_id_returns_plain_uuid_unchanged(self): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id=_DEFAULT_TENANT_ID) == value + + def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id=_DEFAULT_TENANT_ID) is None + + def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id=_DEFAULT_TENANT_ID) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=_DEFAULT_TENANT_ID) is None + + # ── Utility ─────────────────────────────────────────────────────── + + def test_is_valid_uuid_handles_bad_inputs(self): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fa57dd4a6ff..b695ae9fd9e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -658,15 +658,17 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" + new_icon_type = "image" mock_current_user = create_autospec(Account, instance=True) mock_current_user.id = account.id mock_current_user.current_tenant_id = account.current_tenant_id with patch("services.app_service.current_user", mock_current_user): - updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type) assert updated_app.icon == new_icon assert updated_app.icon_background == new_icon_background + assert str(updated_app.icon_type).lower() == new_icon_type assert updated_app.updated_by == account.id # Verify other fields remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py new file mode 100644 index 00000000000..2593b53fe84 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -0,0 +1,211 @@ +""" +Integration tests for AudioService.transcript_tts message-ID path. + +Migrated from unit_tests/services/test_audio_service.py, replacing +db.session.get mock patches with real Message rows persisted in PostgreSQL. + +Covers: +- transcript_tts with valid message_id that resolves to a real Message +- transcript_tts returns None for invalid (non-UUID) message_id +- transcript_tts returns None when message_id is a valid UUID but no row exists +- transcript_tts returns None when message exists but has an empty answer +""" + +from collections.abc import Generator +from decimal import Decimal +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.account import TenantAccountJoin +from models.enums import ConversationFromSource, MessageStatus +from models.model import App, AppMode, Conversation, Message +from services.audio_service import AudioService +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app: App, account_id: str) -> Conversation: + """Create a Conversation row via flush() so the rollback-based teardown can remove it.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + dialogue_count=0, + is_deleted=False, + ) + db_session.add(conversation) + db_session.flush() + return conversation + + +def _create_message( + db_session: Session, + app: App, + conversation: Conversation, + account_id: str, + *, + answer: str = "Message answer text", + status: MessageStatus | str = MessageStatus.NORMAL, +) -> Message: + """Create a Message row via flush() so the rollback-based teardown can remove it.""" + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query="Test query", + message={"messages": [{"role": "user", "content": "Test query"}]}, + message_tokens=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), + answer=answer, + answer_tokens=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal(0), + currency="USD", + status=status, + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + ) + db_session.add(message) + db_session.flush() + return message + + +class TestAudioServiceTranscriptTTSMessageLookup: + """Integration tests for AudioService.transcript_tts message-ID lookup via real DB.""" + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Track rows created by shared helpers that commit, then clean up after the test. + + The shared console helpers (create_console_account_and_tenant, create_console_app) + commit their inserts so the rows survive a simple rollback. This fixture records + the app/account/tenant created per test and explicitly deletes them after the test + so the DB does not accumulate state across tests. Conversation/Message rows are + created via flush() only, so the trailing rollback removes them. + """ + self._committed_rows: list = [] + yield + db_session_with_containers.rollback() + for entity in reversed(self._committed_rows): + db_session_with_containers.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session_with_containers.commit() + + def _setup_app_and_account(self, db_session: Session) -> tuple[App, str, str]: + """Create committed app/account/tenant using shared helpers and track them for cleanup.""" + account, tenant = create_console_account_and_tenant(db_session) + app = create_console_app(db_session, tenant_id=tenant.id, account_id=account.id, mode=AppMode.CHAT) + + # Track rows in the order they must be deleted (FK-safe: app and join before account/tenant) + self._committed_rows.append(app) + join = db_session.scalar( + select(TenantAccountJoin).where( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant.id, + ) + ) + if join is not None: + self._committed_rows.append(join) + self._committed_rows.extend([account, tenant]) + return app, account.id, tenant.id + + def test_transcript_tts_with_message_id_success(self, db_session_with_containers: Session) -> None: + """transcript_tts invokes TTS with the message answer when message_id resolves to a real row.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="Hello from message", + ) + + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"audio from message" + mock_model_manager = MagicMock() + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): + result = AudioService.transcript_tts( + app_model=app, + message_id=message.id, + voice="en-US-Neural", + ) + + assert result == b"audio from message" + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Hello from message", + voice="en-US-Neural", + ) + + def test_transcript_tts_returns_none_for_invalid_message_id(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None immediately when message_id is not a valid UUID.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id="invalid-uuid", + ) + + assert result is None + + def test_transcript_tts_returns_none_for_nonexistent_message(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when message_id is a valid UUID but no Message row exists.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id=str(uuid4()), + ) + + assert result is None + + def test_transcript_tts_returns_none_for_empty_message_answer(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when the resolved message has an empty answer.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="", + status=MessageStatus.NORMAL, + ) + + result = AudioService.transcript_tts( + app_model=app, + message_id=message.id, + ) + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 76708b36b1c..8092c7ad757 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -1,9 +1,13 @@ import json +from collections.abc import Generator from unittest.mock import patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.billing_service import BillingService @@ -363,3 +367,62 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1_new <= 600 assert ttl_2 > 0 assert ttl_2 <= 600 + + +class TestBillingServiceIsTenantOwnerOrAdmin: + """ + Integration tests for BillingService.is_tenant_owner_or_admin. + + Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError + when checked against real TenantAccountJoin rows in PostgreSQL. + """ + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"billing_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session.add(join) + db_session.flush() + + # Wire up in-memory reference so current_tenant_id resolves + account._current_tenant = tenant + return account, tenant + + def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for EDITOR role.""" + account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) + + def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role.""" + account, _ = self._create_account_with_tenant_role( + db_session_with_containers, TenantAccountRole.DATASET_OPERATOR + ) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 6180d98b1ef..98c38f2b5fa 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -637,6 +637,40 @@ class TestConversationServiceSummarization: assert conversation.name == new_name assert conversation.updated_at == mock_time + @patch("services.conversation_service.LLMGenerator.generate_conversation_name") + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + """ + Test rename delegates to auto_generate_name when auto_generate is True. + + When auto_generate is True, the service should call auto_generate_name + which uses an LLM to create a descriptive conversation title. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + ConversationServiceIntegrationTestDataFactory.create_message( + db_session_with_containers, app_model, conversation, user + ) + generated_name = "Auto Generated Name" + mock_llm_generator.return_value = generated_name + + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id=conversation.id, + user=user, + name=None, + auto_generate=True, + ) + + # Assert + assert result == conversation + assert conversation.name == generated_name + class TestConversationServiceMessageAnnotation: """ @@ -1066,3 +1100,32 @@ class TestConversationServiceExport: not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id)) assert not_deleted is not None mock_delete_task.delay.assert_not_called() + + @patch("services.conversation_service.delete_conversation_related_data") + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + """ + Test that delete propagates exceptions and does not trigger the cleanup task. + + When a DB error occurs during deletion, the service must rollback the + transaction and re-raise the exception without scheduling async cleanup. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + conversation_id = conversation.id + + # Act — force an error during the delete to exercise the rollback path + with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")): + with pytest.raises(Exception, match="DB error"): + ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user) + + # Assert — async cleanup must NOT have been scheduled + mock_delete_task.delay.assert_not_called() + + # Conversation is still present because the deletion was never committed + still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id)) + assert still_there is not None diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py new file mode 100644 index 00000000000..0b7bd9ca645 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from graphon.variables import FloatVariable, IntegerVariable, StringVariable +from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource +from models.model import App, Conversation, EndUser +from models.workflow import ConversationVariable +from services.conversation_service import ConversationService +from services.errors.conversation import ( + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) + + +class ConversationServiceVariableIntegrationFactory: + @staticmethod + def create_app_and_account(db_session_with_containers): + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"conversation-variable-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + return app, account + + @staticmethod + def create_end_user(db_session_with_containers, app: App): + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type=InvokeFrom.SERVICE_API.value, + external_user_id=f"external-{uuid4()}", + name=f"End User {uuid4()}", + is_anonymous=False, + session_id=f"session-{uuid4()}", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + return end_user + + @staticmethod + def create_conversation( + db_session_with_containers, + app: App, + user: Account | EndUser, + *, + name: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.WEB_APP, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> Conversation: + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=name or f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=invoke_from.value, + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + if created_at is not None: + conversation.created_at = created_at + if updated_at is not None: + conversation.updated_at = updated_at + + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + return conversation + + @staticmethod + def create_variable( + db_session_with_containers, + *, + app: App, + conversation: Conversation, + variable: StringVariable | FloatVariable | IntegerVariable, + created_at: datetime | None = None, + ) -> ConversationVariable: + row = ConversationVariable.from_variable(app_id=app.id, conversation_id=conversation.id, variable=variable) + if created_at is not None: + row.created_at = created_at + row.updated_at = created_at + + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + +@pytest.fixture +def real_conversation_service_session_factory(flask_app_with_containers): + del flask_app_with_containers + real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + + with ( + patch("services.conversation_service.session_factory.create_session", side_effect=lambda: real_session_maker()), + patch("services.conversation_service.session_factory.get_session_maker", return_value=real_session_maker), + ): + yield + + +class TestConversationServiceVariables: + def test_get_conversational_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + older_time = datetime(2024, 1, 1, 12, 0, 0) + newer_time = older_time + timedelta(minutes=5) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=older_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=newer_time, + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=None, + ) + + assert [item["id"] for item in result.data] == [first_variable.id, second_variable.id] + assert [item["name"] for item in result.data] == ["topic", "priority"] + assert result.limit == 10 + assert result.has_more is False + + def test_get_conversational_variable_with_last_id( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + base_time = datetime(2024, 1, 1, 9, 0, 0) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=base_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=base_time + timedelta(minutes=1), + ) + third_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="owner", value="alice"), + created_at=base_time + timedelta(minutes=2), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=first_variable.id, + ) + + assert [item["id"] for item in result.data] == [second_variable.id, third_variable.id] + assert result.has_more is False + + def test_get_conversational_variable_last_id_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=str(uuid4()), + ) + + def test_get_conversational_variable_sets_has_more( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + for index in range(3): + factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name=f"var_{index}", value=f"value_{index}"), + created_at=datetime(2024, 1, 1, 10, 0, index), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=2, + last_id=None, + ) + + assert len(result.data) == 2 + assert result.has_more is True + + def test_update_conversation_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + ) + updated_at = datetime(2024, 1, 1, 15, 0, 0) + + with patch("services.conversation_service.naive_utc_now", return_value=updated_at): + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="support", + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == "support" + assert result["id"] == existing.id + assert result["value"] == "support" + assert result["updated_at"] == updated_at + + def test_update_conversation_variable_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=str(uuid4()), + user=account, + new_value="support", + ) + + def test_update_conversation_variable_type_mismatch_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=FloatVariable(id=str(uuid4()), name="score", value=1.5), + ) + + with pytest.raises(ConversationVariableTypeMismatchError, match="expects float"): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="wrong-type", + ) + + def test_update_conversation_variable_integer_number_compatibility( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=IntegerVariable(id=str(uuid4()), name="attempts", value=1), + ) + + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value=42, + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == 42 + assert result["value"] == 42 + + +class TestConversationServicePaginationWithContainers: + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=str(uuid4()), + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + base_time = datetime(2024, 1, 1, 8, 0, 0) + newest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Newest", + updated_at=base_time + timedelta(minutes=2), + ) + middle = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Middle", + updated_at=base_time + timedelta(minutes=1), + ) + oldest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Oldest", + updated_at=base_time, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=middle.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert newest.id != middle.id + assert [conversation.id for conversation in result.data] == [oldest.id] + + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") + beta = factory.create_conversation(db_session_with_containers, app, account, name="Beta") + gamma = factory.create_conversation(db_session_with_containers, app, account, name="Gamma") + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=beta.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + sort_by="name", + ) + + assert alpha.id != beta.id + assert [conversation.id for conversation in result.data] == [gamma.id] + + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + end_user_conversation = factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=end_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert account_conversation.id != end_user_conversation.id + assert [conversation.id for conversation in result.data] == [end_user_conversation.id] + + def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert [conversation.id for conversation in result.data] == [account_conversation.id] diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index fb0adbbcc26..02ab3f83146 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,10 +3,10 @@ from uuid import uuid4 import pytest -from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker from extensions.ext_database import db +from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f9bfa570cbb..0de3c64c4f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py new file mode 100644 index 00000000000..2bec703f0cc --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -0,0 +1,650 @@ +"""Testcontainers integration tests for SQL-backed DocumentService paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.account import NoPermissionError + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +class DocumentServiceIntegrationFactory: + @staticmethod + def create_dataset( + db_session_with_containers, + *, + tenant_id: str | None = None, + created_by: str | None = None, + name: str | None = None, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name or f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + dataset: Dataset, + name: str = "doc.txt", + position: int = 1, + tenant_id: str | None = None, + indexing_status: str = IndexingStatus.COMPLETED, + enabled: bool = True, + archived: bool = False, + is_paused: bool = False, + need_summary: bool = False, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + batch: str | None = None, + data_source_type: str = DataSourceType.UPLOAD_FILE, + data_source_info: dict | None = None, + created_by: str | None = None, + ) -> Document: + document = Document( + tenant_id=tenant_id or dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=batch or f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or dataset.created_by, + doc_form=doc_form, + ) + document.indexing_status = indexing_status + document.enabled = enabled + document.archived = archived + document.is_paused = is_paused + document.need_summary = need_summary + if indexing_status == IndexingStatus.COMPLETED: + document.completed_at = FIXED_UPLOAD_CREATED_AT + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_upload_file( + db_session_with_containers, + *, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "source.txt", + ) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + if file_id: + upload_file.id = file_id + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +@pytest.fixture +def current_user_mock(): + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.id = str(uuid4()) + current_user.current_tenant_id = str(uuid4()) + current_user.current_role = None + yield current_user + + +def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_document(dataset.id, None) is None + + +def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) + + result = DocumentService.get_document(dataset.id, document.id) + + assert result is not None + assert result.id == document.id + + +def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + result = DocumentService.get_documents_by_ids(dataset.id, []) + + assert result == [] + + +def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt") + doc_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="b.txt", + position=2, + ) + + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + + assert {document.id for document in result} == {doc_a.id, doc_b.id} + + +def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + + +def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + paragraph_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + need_summary=True, + ) + qa_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + need_summary=True, + doc_form=IndexStructureType.QA_INDEX, + ) + + updated_count = DocumentService.update_documents_need_summary( + dataset.id, + [paragraph_doc.id, qa_doc.id], + need_summary=False, + ) + + db_session_with_containers.expire_all() + refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id) + refreshed_qa = db_session_with_containers.get(Document, qa_doc.id) + assert updated_count == 1 + assert refreshed_paragraph is not None + assert refreshed_qa is not None + assert refreshed_paragraph.need_summary is False + assert refreshed_qa.need_summary is True + + +def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) + + +def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info={"url": "https://example.com"}, + ) + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={}, + ) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": 99}, + ) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + +def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": "missing-file"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + +def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result.id == upload_file.id + + +def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[str(uuid4())], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + tenant_id=str(uuid4()), + data_source_info={"upload_file_id": upload_file.id}, + ) + + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": str(uuid4())}, + ) + + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + mapping = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document_a.id, document_b.id], + tenant_id=dataset.tenant_id, + ) + + assert mapping[document_a.id].id == upload_file_a.id + assert mapping[document_b.id].id == upload_file_b.id + + +def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( + current_user_mock, flask_app_with_containers +): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=str(uuid4()), + document_ids=[str(uuid4())], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + + with patch( + "services.dataset_service.DatasetService.check_dataset_permission", + side_effect=NoPermissionError("denied"), + ): + with pytest.raises(Forbidden, match="denied"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[document_b.id, document_a.id], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] + assert download_name.endswith(".zip") + + +def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + enabled_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + enabled=False, + ) + + result = DocumentService.get_document_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [enabled_document.id] + + +def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + available_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + archived=False, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.ERROR, + ) + + result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [available_document.id] + + +def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + error_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.ERROR, + ) + paused_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.PAUSED, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=3, + indexing_status=IndexingStatus.COMPLETED, + ) + + result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + + assert {document.id for document in result} == {error_document.id, paused_document.id} + + +def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + batch = f"batch-{uuid4()}" + matching_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + batch=batch, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + tenant_id=str(uuid4()), + batch=batch, + ) + + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = dataset.tenant_id + result = DocumentService.get_batch_documents(dataset.id, batch) + + assert [document.id for document in result] == [matching_document.id] + + +def test_get_document_file_detail_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is not None + assert result.id == upload_file.id + + +def test_delete_document_emits_signal_and_commits(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.document_was_deleted.send") as signal_send: + DocumentService.delete_document(document) + + assert db_session_with_containers.get(Document, document.id) is None + signal_send.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id=upload_file.id, + ) + + +def test_delete_documents_ignores_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, []) + + delay.assert_not_called() + + +def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX + db_session_with_containers.commit() + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + + assert db_session_with_containers.get(Document, document_a.id) is None + assert db_session_with_containers.get(Document, document_b.id) is None + delay.assert_called_once() + args = delay.call_args.args + assert args[0] == [document_a.id, document_b.id] + assert args[1] == dataset.id + assert set(args[3]) == {upload_file_a.id, upload_file_b.id} + + +def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3) + + assert DocumentService.get_documents_position(dataset.id) == 4 + + +def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_documents_position(dataset.id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py new file mode 100644 index 00000000000..1b4179c9c70 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -0,0 +1,613 @@ +"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant( + db_session_with_containers: Session, + tenant: Tenant, + role: TenantAccountRole = TenantAccountRole.EDITOR, + ) -> Account: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + *, + tenant_id: str, + created_by: str, + name: str | None = None, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, + enable_api: bool = True, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=name or f"dataset-{uuid4()}", + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + created_by=created_by, + provider="vendor", + permission=permission, + retrieval_model={"top_k": 2}, + ) + dataset.enable_api = enable_api + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_dataset_permission( + db_session_with_containers: Session, + *, + dataset_id: str, + tenant_id: str, + account_id: str, + ) -> DatasetPermission: + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + return permission + + @staticmethod + def create_app_dataset_join( + db_session_with_containers: Session, + *, + dataset_id: str, + ) -> AppDatasetJoin: + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return join + + @staticmethod + def create_collection_binding( + db_session_with_containers: Session, + *, + provider_name: str, + model_name: str, + collection_type: str = "dataset", + ) -> DatasetCollectionBinding: + binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=f"collection_{uuid4().hex}", + type=collection_type, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + return binding + + @staticmethod + def create_auto_disable_log( + db_session_with_containers: Session, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + ) -> DatasetAutoDisableLog: + log = DatasetAutoDisableLog( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + ) + db_session_with_containers.add(log) + db_session_with_containers.commit() + return log + + +class TestDatasetServicePermissionsAndLifecycle: + def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session): + owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + + result = DatasetService.delete_dataset(str(uuid4()), user=owner) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, user=owner) + + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + send_deleted_signal.assert_called_once_with(dataset) + + def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_app_dataset_join( + db_session_with_containers, + dataset_id=dataset.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is True + + def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is False + + def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, outsider) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session): + creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=creator.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_allows_partial_team_member_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_allows_partial_team_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=operator.id, + ) + + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status(str(uuid4()), True) + + def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(id=None)): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + now = datetime(2026, 4, 14, 18, 0, 0) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.naive_utc_now", return_value=now), + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + db_session_with_containers.refresh(dataset) + assert dataset.enable_api is True + assert dataset.updated_by == owner.id + assert dataset.updated_at == now + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(str(uuid4())) + + assert result == {"document_ids": [], "count": 0} + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(dataset.id) + + assert result["count"] == 2 + assert len(result["document_ids"]) == 2 + + +class TestDatasetCollectionBindingServiceIntegration: + def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result.id == binding.id + + def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session): + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model") + + persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id) + assert persisted is not None + assert persisted.provider_name == "provider" + assert persisted.model_name == "missing-model" + assert persisted.type == "dataset" + assert persisted.collection_name + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + + assert result.id == binding.id + + +class TestDatasetPermissionServiceIntegration: + def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_a.id, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_b.id, + ) + + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + assert set(result) == {member_a.id, member_b.id} + + def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=stale_member.id, + ) + + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + [{"user_id": member_a.id}, {"user_id": member_b.id}], + ) + + permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id} + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetPermissionService.clear_partial_member_list(dataset.id) + + remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index a814466e14f..ac0483a45d7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -1,13 +1,21 @@ +import json from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, ExternalKnowledgeBindings +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import ( + Account, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) +from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -25,12 +33,12 @@ class DatasetUpdateTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -103,6 +111,34 @@ class DatasetUpdateTestDataFactory: db_session_with_containers.commit() return binding + @staticmethod + def create_external_knowledge_api( + db_session_with_containers: Session, + tenant_id: str, + created_by: str, + api_id: str | None = None, + name: str = "test-api", + ) -> ExternalKnowledgeApis: + """Create a real external knowledge API template for tenant-scoped update validation.""" + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=created_by, + updated_by=created_by, + name=name, + description="test description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + if api_id is not None: + external_api.id = api_id + db_session_with_containers.add(external_api) + db_session_with_containers.commit() + return external_api + class TestDatasetServiceUpdateDataset: """ @@ -138,6 +174,11 @@ class TestDatasetServiceUpdateDataset: ) binding_id = binding.id db_session_with_containers.expunge(binding) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", @@ -145,7 +186,7 @@ class TestDatasetServiceUpdateDataset: "external_retrieval_model": "new_model", "permission": "only_me", "external_knowledge_id": "new_knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } result = DatasetService.update_dataset(dataset.id, update_data, user) @@ -218,11 +259,16 @@ class TestDatasetServiceUpdateDataset: created_by=user.id, provider="external", ) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", "external_knowledge_id": "knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } with pytest.raises(ValueError) as context: diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index c8f04e92159..fe426ae5161 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index b3e7dd2a59f..315936d7212 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -274,6 +274,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = True mock_config.ALLOW_REGISTER = False mock_config.ALLOW_CREATE_WORKSPACE = False mock_config.MAIL_TYPE = "smtp" @@ -298,6 +299,7 @@ class TestFeatureService: # Verify authentication settings assert result.enable_email_code_login is True assert result.enable_email_password_login is False + assert result.enable_collaboration_mode is True assert result.is_allow_register is False assert result.is_allow_create_workspace is False @@ -401,6 +403,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = False mock_config.ALLOW_REGISTER = True mock_config.ALLOW_CREATE_WORKSPACE = True mock_config.MAIL_TYPE = "smtp" @@ -422,6 +425,7 @@ class TestFeatureService: assert result.enable_email_code_login is True assert result.enable_email_password_login is True assert result.enable_social_oauth_login is False + assert result.enable_collaboration_mode is False assert result.is_allow_register is True assert result.is_allow_create_workspace is True assert result.is_email_setup is True diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5bc..f332ba05ec2 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index c46b8fba0bd..80f9083e815 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,15 +3,15 @@ import uuid from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 0f252515f72..ed75363f3ba 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -5,17 +5,17 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) +from graphon.runtime import VariablePool from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 2340dd2a03d..cd63d3ad6c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,11 +8,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from graphon.file import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index ba926bf6758..8955a3b5f23 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,11 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -644,9 +643,8 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from graphon.model_runtime.entities.common_entities import I18nObject - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 00000000000..e2e1a228b28 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py new file mode 100644 index 00000000000..ccc4188dbf0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.model import AccountTrialAppRecord, TrialApp +from services import recommended_app_service as service_module +from services.recommended_app_service import RecommendedAppService + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _apps_response( + recommended_apps: list[dict] | None = None, + categories: list[str] | None = None, +) -> dict: + if recommended_apps is None: + recommended_apps = [ + {"id": "app-1", "name": "Test App 1", "description": "d1", "category": "productivity"}, + {"id": "app-2", "name": "Test App 2", "description": "d2", "category": "communication"}, + ] + if categories is None: + categories = ["productivity", "communication", "utilities"] + return {"recommended_apps": recommended_apps, "categories": categories} + + +def _app_detail( + app_id: str = "app-123", + name: str = "Test App", + description: str = "Test description", + **kwargs: Any, +) -> dict: + detail: dict[str, Any] = { + "id": app_id, + "name": name, + "description": description, + "category": kwargs.get("category", "productivity"), + "icon": kwargs.get("icon", "🚀"), + "model_config": kwargs.get("model_config", {}), + } + detail.update(kwargs) + return detail + + +def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any] | None: + return cast("dict[str, Any] | None", result) + + +def _mock_factory_for_apps( + monkeypatch: pytest.MonkeyPatch, + *, + mode: str, + result: dict[str, Any], + fallback_result: dict[str, Any] | None = None, +) -> tuple[MagicMock, MagicMock]: + retrieval_instance = MagicMock() + retrieval_instance.get_recommended_apps_and_categories.return_value = result + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + builtin_instance = MagicMock() + if fallback_result is not None: + builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_buildin_recommend_app_retrieval", + MagicMock(return_value=builtin_instance), + ) + return retrieval_instance, builtin_instance + + +# ── Pure logic tests: get_recommended_apps_and_categories ────────────── + + +class TestRecommendedAppServiceGetApps: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success_with_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _apps_response() + + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = expected + mock_factory = MagicMock(return_value=mock_instance) + mock_factory_class.get_recommend_app_factory.return_value = mock_factory + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + assert len(result["recommended_apps"]) == 2 + assert len(result["categories"]) == 3 + mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") + mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + empty_response = {"recommended_apps": [], "categories": []} + builtin_response = _apps_response( + recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}] + ) + + mock_remote_instance = MagicMock() + mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_remote_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN") + + assert result == builtin_response + assert result["recommended_apps"][0]["id"] == "builtin-1" + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db" + none_response = {"recommended_apps": None, "categories": ["test"]} + builtin_response = _apps_response() + + mock_db_instance = MagicMock() + mock_db_instance.get_recommended_apps_and_categories.return_value = none_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_db_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == builtin_response + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_languages(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + + for language in ["en-US", "zh-CN", "ja-JP", "fr-FR"]: + lang_response = _apps_response( + recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}] + ) + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = lang_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = RecommendedAppService.get_recommended_apps_and_categories(language) + + assert result["recommended_apps"][0]["id"] == f"app-{language}" + mock_instance.get_recommended_apps_and_categories.assert_called_with(language) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_uses_correct_factory_mode(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + response = _apps_response() + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + RecommendedAppService.get_recommended_apps_and_categories("en-US") + + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + +# ── Pure logic tests: get_recommend_app_detail ───────────────────────── + + +class TestRecommendedAppServiceGetDetail: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app") + + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-123")) + + assert result == expected + assert result["id"] == "app-123" + mock_instance.get_recommend_app_detail.assert_called_once_with("app-123") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_modes(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + detail = _app_detail(app_id="test-app", name=f"App from {mode}") + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = detail + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("test-app")) + + assert result["name"] == f"App from {mode}" + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_none_when_not_found(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = None + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("nonexistent")) + + assert result is None + mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_empty_dict(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = {} + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-empty")) + + assert result == {} + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_complex_model_config(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + complex_config = { + "provider": "openai", + "model": "gpt-4", + "parameters": {"temperature": 0.7, "max_tokens": 2000, "top_p": 1.0}, + } + expected = _app_detail( + app_id="complex-app", + name="Complex App", + model_config=complex_config, + workflows=["workflow-1", "workflow-2"], + tools=["tool-1", "tool-2", "tool-3"], + ) + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("complex-app")) + + assert result["model_config"] == complex_config + assert len(result["workflows"]) == 2 + assert len(result["tools"]) == 3 + + +# ── Integration tests: trial app features (real DB) ──────────────────── + + +class TestRecommendedAppServiceTrialFeatures: + def test_get_apps_should_not_query_trial_table_when_disabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]} + retrieval_instance, builtin_instance = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called() + + def test_get_apps_should_enrich_can_trial_when_enabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + app_id_1 = str(uuid.uuid4()) + app_id_2 = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + # app_id_1 has a TrialApp record; app_id_2 does not + db_session_with_containers.add(TrialApp(app_id=app_id_1, tenant_id=tenant_id)) + db_session_with_containers.commit() + + remote_result = {"recommended_apps": [], "categories": []} + fallback_result = { + "recommended_apps": [{"app_id": app_id_1}, {"app_id": app_id_2}], + "categories": ["all"], + } + _, builtin_instance = _mock_factory_for_apps( + monkeypatch, mode="remote", result=remote_result, fallback_result=fallback_result + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP") + + builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + assert result["recommended_apps"][0]["can_trial"] is True + assert result["recommended_apps"][1]["can_trial"] is False + + @pytest.mark.parametrize("has_trial_app", [True, False]) + def test_get_detail_should_set_can_trial_when_enabled( + self, + db_session_with_containers: Session, + monkeypatch: pytest.MonkeyPatch, + has_trial_app: bool, + ): + app_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + if has_trial_app: + db_session_with_containers.add(TrialApp(app_id=app_id, tenant_id=tenant_id)) + db_session_with_containers.commit() + + detail = {"id": app_id, "name": "Test App"} + retrieval_instance = MagicMock() + retrieval_instance.get_recommend_app_detail.return_value = detail + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail(app_id)) + + assert result["id"] == app_id + assert result["can_trial"] is has_trial_app + + def test_add_trial_app_record_increments_count_for_existing(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3)) + db_session_with_containers.commit() + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.count == 4 + + def test_add_trial_app_record_creates_new_record(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.app_id == app_id + assert record.account_id == account_id + assert record.count == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_schedule_service.py b/api/tests/test_containers_integration_tests/services/test_schedule_service.py new file mode 100644 index 00000000000..87f33062582 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_schedule_service.py @@ -0,0 +1,387 @@ +"""Testcontainers integration tests for schedule service SQL-backed behavior.""" + +from datetime import datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate +from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError +from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.trigger import WorkflowSchedulePlan +from services.errors.account import AccountNotFoundError +from services.trigger.schedule_service import ScheduleService + + +class ScheduleServiceIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_schedule_plan( + db_session_with_containers: Session, + *, + tenant_id: str, + app_id: str | None = None, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", + next_run_at: datetime | None = None, + ) -> WorkflowSchedulePlan: + schedule = WorkflowSchedulePlan( + tenant_id=tenant_id, + app_id=app_id or str(uuid4()), + node_id=node_id, + cron_expression=cron_expression, + timezone=timezone, + next_run_at=next_run_at, + ) + db_session_with_containers.add(schedule) + db_session_with_containers.commit() + return schedule + + +def _cron_workflow( + *, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", +): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": node_id, + "data": { + "type": "trigger-schedule", + "mode": "cron", + "cron_expression": cron_expression, + "timezone": timezone, + }, + } + ] + } + ) + + +def _no_schedule_workflow(): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": "node-1", + "data": {"type": "llm"}, + } + ] + } + ) + + +class TestScheduleServiceIntegration: + def test_create_schedule_persists_schedule(self, db_session_with_containers: Session): + account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + expected_next_run = datetime(2026, 1, 1, 10, 30, 0) + config = ScheduleConfig( + node_id="start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + schedule = ScheduleService.create_schedule( + session=db_session_with_containers, + tenant_id=tenant.id, + app_id=str(uuid4()), + config=config, + ) + + persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) + assert persisted is not None + assert persisted.tenant_id == tenant.id + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + cron_expression="30 10 * * *", + timezone="UTC", + ) + expected_next_run = datetime(2026, 1, 2, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate( + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + db_session_with_containers.refresh(updated) + assert updated.cron_expression == "0 12 * * *" + assert updated.timezone == "America/New_York" + assert updated.next_run_at == expected_next_run + + def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + initial_next_run = datetime(2026, 1, 1, 10, 0, 0) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + next_run_at=initial_next_run, + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + calls: list[tuple] = [] + + def _track(*args, **kwargs): + calls.append((args, kwargs)) + return datetime(2026, 1, 9, 10, 0, 0) + + monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + db_session_with_containers.refresh(updated) + assert updated.node_id == "node-new" + assert updated.next_run_at == initial_next_run + assert calls == [] + + def test_update_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + def test_delete_schedule_removes_row(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + db_session_with_containers.commit() + + assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None + + def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session): + owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.OWNER, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == owner.id + + def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session): + admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.ADMIN, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == admin.id + + def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + missing_account_id = str(uuid4()) + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=missing_account_id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=missing_account_id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=tenant.id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + expected_next_run = datetime(2026, 1, 3, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + + db_session_with_containers.refresh(schedule) + assert result == expected_next_run + assert schedule.next_run_at == expected_next_run + + def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + +class TestSyncScheduleFromWorkflowIntegration: + def test_sync_schedule_create_new(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + expected_next_run = datetime(2026, 1, 4, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow(), + ) + + assert result is not None + persisted = db_session_with_containers.execute( + select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id) + ).scalar_one() + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_update_existing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + node_id="old-start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + existing_id = existing.id + expected_next_run = datetime(2026, 1, 5, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow( + node_id="start", + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + assert result is not None + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id) + assert persisted is not None + assert persisted.node_id == "start" + assert persisted.cron_expression == "0 12 * * *" + assert persisted.timezone == "America/New_York" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + ) + existing_id = existing.id + + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_no_schedule_workflow(), + ) + + assert result is None + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index f504f355890..5a6bf0466ef 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -12,7 +12,13 @@ from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding -from services.tag_service import TagService +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) class TestTagService: @@ -685,7 +691,7 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - tag_args = {"name": "test_tag_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="test_tag_name", type="knowledge") # Act: Execute the method under test result = TagService.save_tags(tag_args) @@ -725,7 +731,7 @@ class TestTagService: ) # Create first tag - tag_args = {"name": "duplicate_tag", "type": "app"} + tag_args = SaveTagPayload(name="duplicate_tag", type="app") TagService.save_tags(tag_args) # Act & Assert: Verify proper error handling @@ -749,11 +755,11 @@ class TestTagService: ) # Create a tag to update - tag_args = {"name": "original_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="original_name", type="knowledge") tag = TagService.save_tags(tag_args) # Update args - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act: Execute the method under test result = TagService.update_tags(update_args, tag.id) @@ -793,7 +799,7 @@ class TestTagService: non_existent_tag_id = str(uuid.uuid4()) - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -817,14 +823,14 @@ class TestTagService: ) # Create two tags - tag1_args = {"name": "first_tag", "type": "app"} + tag1_args = SaveTagPayload(name="first_tag", type="app") tag1 = TagService.save_tags(tag1_args) - tag2_args = {"name": "second_tag", "type": "app"} + tag2_args = SaveTagPayload(name="second_tag", type="app") tag2 = TagService.save_tags(tag2_args) # Try to update second tag with first tag's name - update_args = {"name": "first_tag", "type": "app"} + update_args = UpdateTagPayload(name="first_tag", type="app") # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: @@ -988,8 +994,10 @@ class TestTagService: dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload( + type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] + ) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1030,11 +1038,11 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Create first binding - binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id]) + TagService.save_tag_binding(binding_payload) # Act: Try to create duplicate binding - TagService.save_tag_binding(binding_args) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1071,11 +1079,10 @@ class TestTagService: non_existent_target_id = str(uuid.uuid4()) # Act & Assert: Verify proper error handling - binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} + from pydantic import ValidationError - with pytest.raises(NotFound) as exc_info: - TagService.save_tag_binding(binding_args) - assert "Invalid binding type" in str(exc_info.value) + with pytest.raises(ValidationError): + TagBindingCreatePayload(type="invalid_type", target_id=non_existent_target_id, tag_ids=[tag.id]) def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1113,8 +1120,8 @@ class TestTagService: assert binding_before is not None # Act: Execute the method under test - delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # Verify tag binding was deleted @@ -1149,8 +1156,8 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Try to delete non-existent binding - delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d58032..7825f502f77 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py new file mode 100644 index 00000000000..ec10c51e049 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import AppTriggerStatus, AppTriggerType +from models.model import App +from models.trigger import AppTrigger, WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger.webhook_service import WebhookService + + +class WebhookServiceRelationshipFactory: + @staticmethod + def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]: + account = Account( + name=f"Account {uuid4()}", + email=f"webhook-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App: + app = App( + tenant_id=tenant.id, + name=f"Webhook App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + @staticmethod + def create_workflow( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_ids: list[str], + version: str, + ) -> Workflow: + graph = { + "nodes": [ + { + "id": node_id, + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "title": f"Webhook {node_id}", + "method": "post", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + "status_code": 200, + "response_body": '{"status": "ok"}', + "timeout": 30, + }, + } + for node_id in node_ids + ], + "edges": [], + } + + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + version=version, + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + return workflow + + @staticmethod + def create_webhook_trigger( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_id: str, + webhook_id: str | None = None, + ) -> WorkflowWebhookTrigger: + webhook_trigger = WorkflowWebhookTrigger( + app_id=app.id, + node_id=node_id, + tenant_id=app.tenant_id, + webhook_id=webhook_id or uuid4().hex[:24], + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + return webhook_trigger + + @staticmethod + def create_app_trigger( + db_session_with_containers: Session, + *, + app: App, + node_id: str, + status: AppTriggerStatus, + ) -> AppTrigger: + app_trigger = AppTrigger( + tenant_id=app.tenant_id, + app_id=app.id, + node_id=node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + provider_name="webhook", + title=f"Webhook {node_id}", + status=status, + ) + db_session_with_containers.add(app_trigger) + db_session_with_containers.commit() + return app_trigger + + +class TestWebhookServiceLookupWithContainers: + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED + ) + + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED + ) + + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED + ) + + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["published-node"], + version="2026-04-14.001", + ) + draft_workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["debug-node"], + version=Workflow.VERSION_DRAFT, + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="debug-node" + ) + + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + webhook_trigger.webhook_id, + is_debug=True, + ) + + assert got_trigger.id == webhook_trigger.id + assert got_workflow.id == draft_workflow.id + assert got_node_config["id"] == "debug-node" + + +class TestWebhookServiceTriggerExecutionWithContainers: + def test_trigger_workflow_execution_triggers_async_workflow_successfully( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + end_user = SimpleNamespace(id=str(uuid4())) + webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=end_user, + ), + patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume, + patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, + ): + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + mock_consume.assert_called_once_with(webhook_trigger.tenant_id) + mock_trigger.assert_called_once() + trigger_args = mock_trigger.call_args.args + assert trigger_args[1] is end_user + assert trigger_args[2].workflow_id == workflow.id + assert trigger_args[2].root_node_id == webhook_trigger.node_id + + def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=SimpleNamespace(id=str(uuid4())), + ), + patch( + "services.trigger.webhook_service.QuotaType.TRIGGER.consume", + side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), + ), + patch( + "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited" + ) as mock_mark_rate_limited, + ): + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_mark_rate_limited.assert_called_once_with(tenant.id) + + def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), + ), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_logger_exception.assert_called_once() + + +class TestWebhookServiceRelationshipSyncWithContainers: + def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)] + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT + ) + + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_raises_when_lock_not_acquired( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = False + + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + stale_trigger = factory.create_webhook_trigger( + db_session_with_containers, + app=app, + account=account, + node_id="node-stale", + webhook_id="stale-webhook-id-000001", + ) + stale_trigger_id = stale_trigger.id + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-new"], + version=Workflow.VERSION_DRAFT, + ) + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + db_session_with_containers.expire_all() + records = db_session_with_containers.scalars( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id) + ).all() + + assert [record.node_id for record in records] == ["node-new"] + assert records[0].webhook_id == "new-webhook-id-000001" + assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None + + def test_sync_webhook_relationships_sets_redis_cache_for_new_record( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-cache"], + version=Workflow.VERSION_DRAFT, + ) + cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache" + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key) + assert cached_payload is not None + assert cached_payload["node_id"] == "node-cache" + assert cached_payload["webhook_id"] == "cache-webhook-id-00001" + + def test_sync_webhook_relationships_logs_when_lock_release_fails( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + with ( + patch("services.trigger.webhook_service.redis_client.lock", return_value=lock), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + WebhookService.sync_webhook_relationships(app, workflow) + + mock_logger_exception.assert_called_once() + + +def _read_cache(cache_key: str) -> dict[str, str] | None: + from extensions.ext_redis import redis_client + + cached = redis_client.get(cache_key) + if not cached: + return None + if isinstance(cached, bytes): + cached = cached.decode("utf-8") + return json.loads(cached) + + +WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 749c6fff5bc..1e57b5603de 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 0c281c8c33b..86cf2327c7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker -from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index d3e765055a0..af83adaae01 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -1,3 +1,5 @@ +import inspect +import json from unittest.mock import patch import pytest @@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.errors import ApiToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -590,30 +594,204 @@ class TestApiToolManageService: with pytest.raises(ValueError, match="you have not added provider"): ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") - def test_update_api_tool_provider_not_found( + def test_update_api_tool_provider_success( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): - """Test update raises ValueError when original provider not found.""" fake = Faker() + + # Firmware fix for cache.delete() in update flow + mock_encrypter = mock_external_service_dependencies["encrypter"] + from unittest.mock import MagicMock + + mock_cache = MagicMock() + mock_cache.delete.return_value = None + mock_encrypter.return_value = (mock_encrypter, mock_cache) + + # Get fake account and tenant account, tenant = self._create_test_account_and_tenant( db_session_with_containers, mock_external_service_dependencies ) - with pytest.raises(ValueError, match="does not exists"): - ApiToolManageService.update_api_tool_provider( + # original provider name + original_name = "original-provider" + + # Create original provider + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=original_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="", + custom_disclaimer="", + labels=["old-label"], + ) + + # new provide name and new labels for update + new_name = "updated-provider" + new_labels = ["new-label-1", "new-label-2"] + + # Reset mock history so assertions focus on update path only + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + + # Act: Update the provider with new values + result = ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + # new provider name - changed 1 + provider_name=new_name, + original_provider=original_name, + # new icon - changed 2 + icon={"type": "emoji", "value": "🚀"}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + # new privacy policy - changed 3 + privacy_policy="https://new-policy.com", + # new custom disclaimer - changed 4 + custom_disclaimer="New disclaimer", + # new labels - changed 5 (However, we will not verify this, not this layer responsibility.) + labels=new_labels, + ) + + # Assert: Verify the result + assert result == {"result": "success"} + + # Get the updated provider from the database + updated_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name) + .first() + ) + + # Verify the provider was updated successfully + assert updated_provider is not None + + # Manually refresh to keep object detachment + db_session_with_containers.refresh(updated_provider) + # Verify all the updated fields + # - changed 1 + assert updated_provider.name == new_name + # - changed 2 + icon_data = json.loads(updated_provider.icon) + assert icon_data["type"] == "emoji" + assert icon_data["value"] == "🚀" + # - changed 3 + assert updated_provider.privacy_policy == "https://new-policy.com" + # - changed 4 + assert updated_provider.custom_disclaimer == "New disclaimer" + + # Verify old provider name no longer exists after rename + original_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name) + .first() + ) + assert original_provider is None + + # Verify update flow calls critical collaborators + mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + mock_external_service_dependencies["encrypter"].assert_called_once() + mock_cache.delete.assert_called_once() + + # Deeply verify on session propagation of labels update logics: + # Since in refactoring, we pass session down to label manager to keep atomicity. + # The assertion here is to verify this. + sig = inspect.signature(ToolLabelManager.update_tool_labels) + args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args + bound_args = sig.bind(*args, **kwargs) + passed_session = bound_args.arguments.get("session") + # Ensure the type: Session + assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}" + assert passed_session is not None, ( + "Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider" + ) + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update raises ValueError when original provider not found. + + This test verifies: + - Proper error when trying to update a non-existing original provider + - No accidental upsert/new provider creation + - No external dependency invocation on early failure path + """ + # Arrange: Create test account and tenant + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Keep an existing provider in DB to ensure unrelated data remains unchanged + existing_provider_name = "existing-provider" + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=existing_provider_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="https://existing-policy.com", + custom_disclaimer="Existing disclaimer", + labels=["existing-label"], + ) + + # Reset mock history so assertions focus on update failure path only + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + + # Act & Assert: Verify update fails with clear error message + target_new_name = "new-provider-name" + missing_original_name = "missing-original-provider" + with pytest.raises(ApiToolProviderNotFoundError) as exc_info: + _ = ApiToolManageService.update_api_tool_provider( user_id=account.id, tenant_id=tenant.id, - provider_name="new-name", - original_provider="nonexistent", - icon={}, + provider_name=target_new_name, + original_provider=missing_original_name, + icon={"type": "emoji", "value": "🚀"}, credentials={"auth_type": "none"}, _schema_type=ApiProviderSchemaType.OPENAPI, schema=self._create_test_openapi_schema(), - privacy_policy=None, - custom_disclaimer="", - labels=[], + privacy_policy="https://new-policy.com", + custom_disclaimer="New disclaimer", + labels=["new-label"], ) + error = exc_info.value + assert error.provider_name == missing_original_name + assert error.tenant_id == tenant.id + assert error.error_code == "api_tool_provider_not_found" + + # Assert: Existing provider should remain unchanged + existing_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name) + .first() + ) + assert existing_provider is not None + assert existing_provider.name == existing_provider_name + + # Assert: No new provider should be created + unexpected_new_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name) + .first() + ) + assert unexpected_new_provider is None + + # Assert: Early failure should skip all downstream external interactions + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called() + mock_external_service_dependencies["encrypter"].assert_not_called() + mock_external_service_dependencies["provider_controller"].from_db.assert_not_called() + def test_update_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce2fd2eeb1a..ce5c2bd162f 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,9 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -21,6 +18,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 7c43bf676b0..4dab895135b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb1..fcc15aad426 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe431370..e29ca7ebab4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32ff..05827112d4a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1dd37fbc92c..32bc2fc0bd7 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -52,18 +53,18 @@ class TestCleanDatasetTask: from extensions.ext_redis import redis_client # Clear all test data using the provided session fixture - db_session_with_containers.query(DatasetMetadataBinding).delete() - db_session_with_containers.query(DatasetMetadata).delete() - db_session_with_containers.query(AppDatasetJoin).delete() - db_session_with_containers.query(DatasetQuery).delete() - db_session_with_containers.query(DatasetProcessRule).delete() - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DatasetMetadataBinding)) + db_session_with_containers.execute(delete(DatasetMetadata)) + db_session_with_containers.execute(delete(AppDatasetJoin)) + db_session_with_containers.execute(delete(DatasetQuery)) + db_session_with_containers.execute(delete(DatasetProcessRule)) + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -302,28 +303,40 @@ class TestCleanDatasetTask: # Verify results # Check that dataset-related data was cleaned up - documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() assert len(documents) == 0 - segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(metadata) == 0 - bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.scalars( + select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id) + ).all() assert len(process_rules) == 0 - queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.scalars( + select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id) + ).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.scalars( + select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id) + ).all() assert len(app_joins) == 0 # Verify index processor was called @@ -414,24 +427,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify index processor was called @@ -485,12 +506,14 @@ class TestCleanDatasetTask: # Check that all data was cleaned up - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 - remaining_segments = ( - db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - ) + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Recreate data for next test case @@ -538,11 +561,15 @@ class TestCleanDatasetTask: # Verify results - even with vector cleanup failure, documents and segments should be deleted # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -622,18 +649,22 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = ( - db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() - ) + remaining_image_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(image_file_ids)) + ).all() assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -738,24 +769,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify performance expectations @@ -826,7 +865,9 @@ class TestCleanDatasetTask: # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file.id) + ).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -976,19 +1017,27 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file_id) + ).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 926c839c8b6..fa3ac12cf00 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,6 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import ColumnElement, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -20,6 +22,14 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password +def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 + + +def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 + + class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -145,24 +155,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -322,12 +322,7 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Reset mock for next iteration mock_index_processor_factory.reset_mock() @@ -410,10 +405,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. @@ -499,11 +491,8 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 10 - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -513,22 +502,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(documents_to_clean)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 4 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. @@ -612,19 +591,13 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 4 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. @@ -794,12 +767,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() - == num_documents - ) - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == num_documents * num_segments_per_doc ) @@ -808,10 +778,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. @@ -906,8 +873,8 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup # Note: There may be documents from previous tests, so we check for at least 3 - assert db_session_with_containers.query(Document).count() >= 3 - assert db_session_with_containers.query(DocumentSegment).count() >= 9 + assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3 + assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9 # Clean up documents from only the first dataset target_dataset = datasets[0] @@ -918,22 +885,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == target_document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. @@ -1028,11 +985,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len( - document_statuses - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == len(document_statuses) * 2 ) @@ -1041,10 +996,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. @@ -1142,20 +1094,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 3 - ) + assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9f8e37fc9ec..9084667c314 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy import delete from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask: """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a6..684097851b6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration: def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index d457b59d588..48fec441c58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -221,7 +222,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called @@ -322,7 +325,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "update") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called @@ -431,7 +436,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist @@ -564,7 +571,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to error - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error @@ -635,7 +644,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type @@ -711,7 +722,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type @@ -815,7 +828,9 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times @@ -917,7 +932,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify final document status - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( @@ -1027,12 +1044,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only enabled document was processed - updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + updated_enabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == enabled_document.id).limit(1) + ) assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged - updated_disabled_document = ( - db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + updated_disabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == disabled_document.id).limit(1) ) assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1148,12 +1167,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only active document was processed - updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + updated_active_document = db_session_with_containers.scalar( + select(Document).where(Document.id == active_document.id).limit(1) + ) assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged - updated_archived_document = ( - db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + updated_archived_document = db_session_with_containers.scalar( + select(Document).where(Document.id == archived_document.id).limit(1) ) assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1269,14 +1290,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only completed document was processed - updated_completed_document = ( - db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + updated_completed_document = db_session_with_containers.scalar( + select(Document).where(Document.id == completed_document.id).limit(1) ) assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged - updated_incomplete_document = ( - db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + updated_incomplete_document = db_session_with_containers.scalar( + select(Document).where(Document.id == incomplete_document.id).limit(1) ) assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 3e9a0c8f7f0..6e03bd93514 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask: db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() - ) + updated_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + ).all() for segment in updated_segments: assert segment.enabled is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index d4021143eff..b6e7e6e5c92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -12,10 +12,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy import delete, func, select, update from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -30,12 +31,12 @@ class DocumentIndexingSyncTaskTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.flush() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.flush() @@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask: """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) - db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( - {"data_source_info": None} + db_session_with_containers.execute( + update(Document).where(Document.id == context["document"].id).values(data_source_info=None) ) db_session_with_containers.commit() @@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR @@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask: context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index d94abf2b40a..a9a8c0f30cb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask: db_session_with_containers.expire_all() # Assert document status updated before reindex - updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1)) assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining == 0 @@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask: mock_external_dependencies["runner_instance"].run.assert_called_once() # Segments should remain (since clean failed before DB delete) - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining > 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 6a8e1869580..39c58987fd0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -317,7 +318,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -362,14 +363,14 @@ class TestDuplicateDocumentIndexingTasks: # Verify segments were deleted from database # Re-query segments from database using captured IDs to avoid stale ORM instances for seg_id in segment_ids: - deleted_segment = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == seg_id).limit(1) ) assert deleted_segment is None # Verify documents were updated to parsing status for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -438,7 +439,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify only existing documents were updated # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -485,7 +486,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -543,7 +544,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() @@ -585,7 +586,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() @@ -649,7 +650,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -692,7 +693,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -736,7 +737,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.is_paused is True assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index c0ddc27286d..83437119988 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index a16f3ff773b..95a867dbb5c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,16 +3,14 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool +from sqlalchemy import delete from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -20,6 +18,9 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient @@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(HumanInputFormRecipient).delete() - db_session_with_containers.query(HumanInputDelivery).delete() - db_session_with_containers.query(HumanInputForm).delete() - db_session_with_containers.query(WorkflowPause).delete() - db_session_with_containers.query(WorkflowRun).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(HumanInputFormRecipient)) + db_session_with_containers.execute(delete(HumanInputDelivery)) + db_session_with_containers.execute(delete(HumanInputForm)) + db_session_with_containers.execute(delete(WorkflowPause)) + db_session_with_containers.execute(delete(WorkflowRun)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index 212fbd26cd4..d34828c4b14 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -44,9 +45,9 @@ class TestMailInviteMemberTask: def cleanup_database(self, db_session_with_containers): """Clean up database before each test to ensure isolation.""" # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -491,10 +492,10 @@ class TestMailInviteMemberTask: assert tenant.name is not None # Verify tenant relationship exists - tenant_join = ( - db_session_with_containers.query(TenantAccountJoin) - .filter_by(tenant_id=tenant.id, account_id=pending_account.id) - .first() + tenant_join = db_session_with_containers.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id) + .limit(1) ) assert tenant_join is not None assert tenant_join.role == TenantAccountRole.NORMAL diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 96cf9cebf5e..b43b6228701 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,12 @@ import uuid from unittest.mock import ANY, call, patch import pytest -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole @@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import ( @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(WorkflowDraftVariable).delete() - db_session_with_containers.query(WorkflowDraftVariableFile).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(App).delete() - db_session_with_containers.query(Tenant).delete() + db_session_with_containers.execute(delete(WorkflowDraftVariable)) + db_session_with_containers.execute(delete(WorkflowDraftVariableFile)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(App)) + db_session_with_containers.execute(delete(Tenant)) db_session_with_containers.commit() @@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch: result = delete_draft_variables_batch(app1.id, batch_size=100) assert result == 150 - app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app1.id + app1_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id) ) - app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app2.id + app2_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id) ) - assert app1_remaining.count() == 0 - assert app2_remaining.count() == 100 + assert app1_remaining_count == 0 + assert app2_remaining_count == 100 def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) assert result == 0 - assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0 + assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0 @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.logger") @@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData: expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys] mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 @patch("extensions.ext_storage.storage") @patch("tasks.remove_app_and_related_data_task.logging") @@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData: assert result == 1 mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0]) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 159ab51304a..b00d827e37c 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,16 +24,16 @@ from dataclasses import dataclass from datetime import timedelta import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.model import UploadFile from models.workflow import Workflow, WorkflowRun from repositories.sqlalchemy_api_workflow_run_repository import ( @@ -181,7 +181,7 @@ class TestWorkflowPauseIntegration: tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -190,7 +190,7 @@ class TestWorkflowPauseIntegration: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration: # Verify only 3 were deleted remaining_count = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])) - .count() + self.session.scalar( + select(func.count(WorkflowPauseModel.id)).where( + WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]) + ) + ) + or 0 ) assert remaining_count == 2 @@ -693,7 +696,7 @@ class TestWorkflowPauseIntegration: tenant2 = Tenant( name="Test Tenant 2", - status="normal", + status=TenantStatus.NORMAL, ) self.session.add(tenant2) self.session.commit() @@ -702,7 +705,7 @@ class TestWorkflowPauseIntegration: email="test2@example.com", name="Test User 2", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) self.session.add(account2) self.session.commit() diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py index e3832fb2ef2..272bee9630d 100644 --- a/api/tests/test_containers_integration_tests/trigger/conftest.py +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -11,6 +11,7 @@ from collections.abc import Generator from typing import Any import pytest +from sqlalchemy import delete from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T yield tenant, account # Cleanup - db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db_session_with_containers.query(Account).filter_by(id=account.id).delete() - db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete() + db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id)) + db_session_with_containers.execute(delete(Account).where(Account.id == account.id)) + db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id)) db_session_with_containers.commit() @@ -93,14 +94,14 @@ def app_model( ) from models.workflow import Workflow - db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete() - db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete() - db_session_with_containers.query(App).filter_by(id=app.id).delete() + db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id)) + db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id)) + db_session_with_containers.execute(delete(App).where(App.id == app.id)) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 7539bae6855..55aec498781 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,7 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient -from graphon.enums import BuiltinNodeTypes +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -24,6 +24,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus @@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log( assert response.status_code == 200 db_session_with_containers.expire_all() - logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + logs = db_session_with_containers.scalars( + select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id) + ).all() assert logs, "Webhook trigger should create trigger log" @@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log( # Verify WorkflowTriggerLog was created db_session_with_containers.expire_all() - logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + logs = db_session_with_containers.scalars( + select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id) + ).all() assert logs, "Schedule trigger should create WorkflowTriggerLog" assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE assert logs[0].root_node_id == schedule_node_id @@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification( # Verify database records exist db_session_with_containers.expire_all() - plugin_triggers = ( - db_session_with_containers.query(WorkflowPluginTrigger) - .filter_by(app_id=app_model.id, node_id=plugin_node_id) - .all() - ) + plugin_triggers = db_session_with_containers.scalars( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_model.id, + WorkflowPluginTrigger.node_id == plugin_node_id, + ) + ).all() assert plugin_triggers, "WorkflowPluginTrigger record should exist" assert plugin_triggers[0].provider_id == provider_id assert plugin_triggers[0].event_name == "test_event" diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index d6933e2180f..bad246a4bb1 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): - """Test that DB_EXTRAS options are properly merged with default timezone setting""" + """Test that DB_EXTRAS options are merged with the default timezone startup option.""" # Set environment variables monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") @@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): # Create config config = DifyConfig() - # Get engine options - engine_options = config.SQLALCHEMY_ENGINE_OPTIONS - - # Verify options contains both search_path and timezone - options = engine_options["connect_args"]["options"] + options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options +def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema") + monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "") + + config = DifyConfig() + + assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == { + "options": "-c search_path=myschema", + } + + def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): os.environ.clear() @@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest. _ = DifyConfig().normalized_pubsub_redis_url +def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "" + + +def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "enterprise-a" + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 00000000000..9c4678aed31 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,139 @@ +"""Unit tests for console app import endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + return fake_session + + +class TestAppImportApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportApi() + + def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=True) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +class TestAppImportConfirmApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportConfirmApi() + + def test_import_confirm_returns_failed_status_and_rolls_back( + self, api, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + method = _unwrap(api.post) + + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037dd..35d07a987d0 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -138,12 +138,15 @@ def app_models(app_module): def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses a deterministic helper for tests.""" - def _fake_signed_url(key: str | None) -> str | None: - if not key: + def _fake_build_icon_url(_icon_type, key: str | None) -> str | None: + if key is None: + return None + icon_type = str(_icon_type).lower() + if icon_type != "image": return None return f"signed:{key}" - monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url) def _ts(hour: int = 12) -> datetime: diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index c52bc02420e..2d218dac7e5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,7 +4,6 @@ import io from types import SimpleNamespace import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -21,6 +20,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 11b3b3470d6..24b7e39f737 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -33,12 +33,17 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1")) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -71,12 +76,17 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py deleted file mode 100644 index f588ab261d8..00000000000 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from controllers.console.app.conversation import _get_conversation - - -def test_get_conversation_mark_read_keeps_updated_at_unchanged(): - app_model = SimpleNamespace(id="app-id") - account = SimpleNamespace(id="account-id") - conversation = MagicMock() - conversation.id = "conversation-id" - - with ( - patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, None), - autospec=True, - ), - patch( - "controllers.console.app.conversation.naive_utc_now", - return_value=datetime(2026, 2, 9, 0, 0, 0), - autospec=True, - ), - patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, - ): - mock_session.scalar.return_value = conversation - - _get_conversation(app_model, "conversation-id") - - statement = mock_session.execute.call_args[0][0] - compiled = statement.compile() - sql_text = str(compiled).lower() - compact_sql_text = sql_text.replace(" ", "") - params = compiled.params - - assert "updated_at=current_timestamp" not in compact_sql_text - assert "updated_at=conversations.updated_at" in compact_sql_text - assert "read_at=:read_at" in compact_sql_text - assert "read_account_id=:read_account_id" in compact_sql_text - assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0) - assert params["read_account_id"] == "account-id" diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py new file mode 100644 index 00000000000..1a412aff29b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from contextlib import nullcontext +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from pydantic import ValidationError + +from controllers.console.app import conversation_variables as conversation_variables_module +from graphon.variables.types import SegmentType + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + created_at = datetime(2026, 1, 1, tzinfo=UTC) + updated_at = datetime(2026, 1, 2, tzinfo=UTC) + row = SimpleNamespace( + created_at=created_at, + updated_at=updated_at, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-1", + "name": "my_var", + "value_type": "string", + "value": "value", + "description": "desc", + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["page"] == 1 + assert response["limit"] == 100 + assert response["total"] == 1 + assert response["has_more"] is False + assert response["data"][0]["id"] == "var-1" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["updated_at"] == int(updated_at.timestamp()) + + +def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + row = SimpleNamespace( + created_at=None, + updated_at=None, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-2", + "name": "my_var_2", + "value_type": SegmentType.INTEGER, + "value": 42, + "description": None, + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["data"][0]["value_type"] == "number" + assert response["data"][0]["value"] == "42" + + +def test_get_conversation_variables_requires_conversation_id(app) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): + with pytest.raises(ValidationError): + method(app_model=SimpleNamespace(id="app-1")) diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py new file mode 100644 index 00000000000..1af15d8dc63 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py @@ -0,0 +1,138 @@ +import datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch + +from flask import Flask + +from controllers.console import console_ns +from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _ValidatedResponse: + def __init__(self, payload): + self._payload = payload + + def model_dump(self, mode="json"): + return self._payload + + +class TestAppMCPServerResponse: + def test_parameters_json_string_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '{"key": "value"}', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == {"key": "value"} + + def test_parameters_invalid_json_returns_original(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": "not-valid-json", + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == "not-valid-json" + + def test_parameters_dict_passthrough(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {"already": "parsed"}, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == {"already": "parsed"} + + def test_parameters_json_array_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '["a", "b"]', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == ["a", "b"] + + def test_timestamps_normalized(self): + dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {}, + "created_at": dt, + "updated_at": dt, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.created_at == int(dt.timestamp()) + assert resp.updated_at == int(dt.timestamp()) + + def test_timestamps_none(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {}, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.created_at is None + assert resp.updated_at is None + + +class TestAppMCPServerController: + def test_get_returns_empty_dict_when_server_missing(self): + api = AppMCPServerController() + method = unwrap(api.get) + + with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None): + response = method(api, app_model=SimpleNamespace(id="app-1")) + + assert response == {} + + def test_post_returns_201(self): + api = AppMCPServerController() + method = unwrap(api.post) + payload = {"parameters": {"timeout": 30}} + app = Flask(__name__) + app.config["TESTING"] = True + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")), + patch("controllers.console.app.mcp_server.db.session.add"), + patch("controllers.console.app.mcp_server.db.session.commit"), + patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"), + patch( + "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate", + return_value=_ValidatedResponse({"id": "server-1"}), + ), + ): + response, status_code = method( + api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") + ) + + assert response == {"id": "server-1"} + assert status_code == 201 diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index a76e9588296..c984dbef5d7 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import UTC, datetime + import pytest from controllers.console.app import message as message_module @@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) assert len(response.data) == 2 assert response.data[0] == "What is AI?" + + +def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageDetailResponse normalizes alias fields and datetime timestamps.""" + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = message_module.MessageDetailResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "conversation_id": "550e8400-e29b-41d4-a716-446655440001", + "inputs": {"foo": "bar"}, + "query": "hello", + "re_sign_file_url_answer": "world", + "from_source": "user", + "status": "normal", + "created_at": created_at, + "message_metadata_dict": {"token_usage": 3}, + } + ) + assert response.answer == "world" + assert response.metadata == {"token_usage": 3} + assert response.created_at == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 36076368805..e91c0a05972 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -1,15 +1,16 @@ from __future__ import annotations +import json from datetime import datetime from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from graphon.file import File, FileTransferMethod, FileType def _unwrap(func): @@ -30,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) @@ -258,6 +259,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( assert exc.value.description == "invalid workflow graph" +def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.PublishedAllWorkflowApi() + handler = _unwrap(api.get) + + session_state = {"open": False} + + class _SessionContext: + def __enter__(self): + session_state["open"] = True + return object() + + def __exit__(self, exc_type, exc, tb): + session_state["open"] = False + return False + + class _SessionMaker: + def begin(self): + return _SessionContext() + + class _Workflow: + @property + def id(self): + assert session_state["open"] is True + return "w1" + + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False), + ), + ) + + def _fake_marshal(items, fields): + assert session_state["open"] is True + return [{"id": item.id} for item in items] + + monkeypatch.setattr(workflow_module, "marshal", _fake_marshal) + + with app.test_request_context( + "/apps/app/workflows", + method="GET", + query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1")) + + assert response == { + "items": [{"id": "w1"}], + "page": 1, + "limit": 10, + "has_more": False, + } + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) @@ -290,3 +348,87 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk ): with pytest.raises(NotFound): handler(api, app_model=SimpleNamespace(id="app")) + + +def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None: + app_id_1 = "11111111-1111-1111-1111-111111111111" + app_id_2 = "22222222-2222-2222-2222-222222222222" + signed_avatar_url = "https://files.example.com/signed/avatar-1" + sign_avatar = Mock(return_value=signed_avatar_url) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: {app_id_1}), + ) + monkeypatch.setattr(workflow_module.file_helpers, "get_signed_file_url", sign_avatar) + + workflow_module.redis_client.hgetall.side_effect = lambda key: ( + { + b"sid-1": json.dumps( + { + "user_id": "u-1", + "username": "Alice", + "avatar": "avatar-file-id", + "sid": "sid-1", + } + ) + } + if key == f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + else {} + ) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={app_id_1},{app_id_2}", + method="GET", + ): + response = handler(api) + + assert response == { + "data": [ + { + "app_id": app_id_1, + "users": [ + { + "user_id": "u-1", + "username": "Alice", + "avatar": signed_avatar_url, + "sid": "sid-1", + } + ], + } + ] + } + workflow_module.redis_client.hgetall.assert_called_once_with( + f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + ) + sign_avatar.assert_called_once_with("avatar-file-id") + + +def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + accessible_app_ids = Mock(return_value=set()) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=accessible_app_ids), + ) + + excessive_ids = ",".join(f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS + 1)) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={excessive_ids}", + method="GET", + ): + with pytest.raises(HTTPException) as exc: + handler(api) + + assert exc.value.code == 400 + assert "Maximum" in exc.value.description + accessible_app_ids.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py new file mode 100644 index 00000000000..a9853f25b01 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from controllers.console.app import workflow_app_log as workflow_app_log_module +from graphon.enums import WorkflowExecutionStatus + + +def test_workflow_app_log_query_parses_bool_and_datetime(): + query = workflow_app_log_module.WorkflowAppLogQuery.model_validate( + { + "detail": "true", + "created_at__before": "2026-01-02T03:04:05Z", + "page": "2", + "limit": "10", + } + ) + + assert query.detail is True + assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + assert query.page == 2 + assert query.limit == 10 + + +def test_workflow_app_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "log-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.SUCCEEDED, + "created_at": created_at, + "finished_at": created_at, + }, + "details": {"trigger_metadata": {}}, + "created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"}, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "succeeded" + assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + + +def test_workflow_archived_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "archived-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.FAILED, + }, + "trigger_metadata": {"type": "trigger-plugin"}, + "created_by_end_user": { + "id": "eu-1", + "type": "anonymous", + "is_anonymous": True, + "session_id": "session-1", + }, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "failed" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py new file mode 100644 index 00000000000..85afcf0e60a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_comment as workflow_comment_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = role + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app() -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", status="normal", mode="workflow") + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(workflow_comment_module, "current_user", account) + + +def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None: + for method_name in ( + "create_comment", + "update_comment", + "delete_comment", + "resolve_comment", + "validate_comment_access", + "create_reply", + "update_reply", + "delete_reply", + ): + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, method_name, MagicMock()) + + +def _patch_payload(payload: dict[str, object] | None): + if payload is None: + return nullcontext() + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +@dataclass(frozen=True) +class WriteCase: + resource_cls: type + method_name: str + path: str + kwargs: dict[str, str] + payload: dict[str, object] | None = None + + +@pytest.mark.parametrize( + "case", + [ + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentListApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments", + kwargs={"app_id": "app-123"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentResolveApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/resolve", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + ), + ], +) +def test_write_endpoints_require_edit_permission(app: Flask, monkeypatch: pytest.MonkeyPatch, case: WriteCase) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + _patch_write_services(monkeypatch) + + with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload): + with _patch_payload(case.payload): + handler = getattr(case.resource_cls(), case.method_name) + with pytest.raises(Forbidden): + handler(**case.kwargs) + + +def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + create_comment_mock = MagicMock(return_value={"id": "comment-1"}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock) + payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload): + with _patch_payload(payload): + result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") + + if isinstance(result, tuple): + response = result[0] + else: + response = result + assert response["id"] == "comment-1" + create_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + created_by="account-123", + content="hello", + position_x=1.0, + position_y=2.0, + mentioned_user_ids=[], + ) + + +def test_update_comment_omits_mentions_when_payload_does_not_include_them( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock) + payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload): + with _patch_payload(payload): + workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + + update_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + comment_id="comment-1", + user_id="account-123", + content="hello", + position_x=10.0, + position_y=20.0, + mentioned_user_ids=None, + ) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index e11102acb1e..c4a81484466 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py new file mode 100644 index 00000000000..5363aa154f7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from controllers.console.app import workflow_trigger as workflow_trigger_module + + +def test_parser_models_validate(): + parser = workflow_trigger_module.Parser(node_id="node-1") + enable_parser = workflow_trigger_module.ParserEnable( + trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True + ) + + assert parser.node_id == "node-1" + assert enable_parser.enable_trigger is True + + +def test_workflow_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="https://example.com/icon", + status="enabled", + created_at=created_at, + updated_at=created_at, + ) + + payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump( + mode="json" + ) + assert payload["id"] == "trigger-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" + assert payload["updated_at"] == "2026-01-02T03:04:05Z" + + +def test_webhook_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + webhook = { + "id": "webhook-1", + "webhook_id": "whk-1", + "webhook_url": "https://example.com/hook", + "webhook_debug_url": "https://example.com/hook/debug", + "node_id": "node-1", + "created_at": created_at, + } + + payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json") + assert payload["webhook_id"] == "whk-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 740da1f1df1..22b80b748e6 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal -from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -16,6 +15,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 560971206f6..0cf97da8785 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -14,18 +14,20 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Unauthorized from controllers.console.auth.error import ( AuthenticationFailedError, EmailPasswordLoginLimitError, InvalidEmailError, ) -from controllers.console.auth.login import LoginApi, LogoutApi +from controllers.console.auth.login import EmailCodeLoginApi, LoginApi, LogoutApi from controllers.console.error import ( AccountBannedError, AccountInFreezeError, WorkspacesLimitExceeded, ) +from services.entities.auth_entities import LoginFailureReason from services.errors.account import AccountLoginError, AccountPasswordError @@ -34,6 +36,11 @@ def encode_password(password: str) -> str: return base64.b64encode(password.encode("utf-8")).decode() +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -197,12 +204,17 @@ class TestLoginApi: mock_get_invitation.return_value = None # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(EmailPasswordLoginLimitError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(EmailPasswordLoginLimitError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @@ -220,12 +232,17 @@ class TestLoginApi: mock_is_frozen.return_value = True # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(AccountInFreezeError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(AccountInFreezeError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "frozen@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -257,14 +274,20 @@ class TestLoginApi: mock_authenticate.side_effect = AccountPasswordError("Invalid password") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AuthenticationFailedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("WrongPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AuthenticationFailedError): + login_api.post() mock_add_rate_limit.assert_called_once_with("test@example.com") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -288,12 +311,19 @@ class TestLoginApi: mock_authenticate.side_effect = AccountLoginError("Account is banned") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AccountBannedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AccountBannedError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "banned@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -417,6 +447,36 @@ class TestLoginApi: mock_add_rate_limit.assert_not_called() mock_reset_rate_limit.assert_called_once_with("upper@example.com") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") + @patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token") + @patch("controllers.console.auth.login._get_account_with_case_fallback") + def test_email_code_login_logs_banned_account( + self, + mock_get_account, + mock_revoke_token, + mock_get_token_data, + mock_db, + app, + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_account.side_effect = Unauthorized("Account is banned.") + + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + class TestLogoutApi: """Test cases for the LogoutApi endpoint.""" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9c9f8da87c1..5136922e88e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -18,6 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 6ef8ccfdbd3..63950736c54 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response -from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -16,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor ) from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 8555900f4ec..9465936f280 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi: method = unwrap(api.get) mock_key_1 = MagicMock(spec=ApiToken) + mock_key_1.id = "key-1" + mock_key_1.type = "dataset" + mock_key_1.token = "ds-abc" + mock_key_1.last_used_at = None + mock_key_1.created_at = None mock_key_2 = MagicMock(spec=ApiToken) + mock_key_2.id = "key-2" + mock_key_2.type = "dataset" + mock_key_2.token = "ds-def" + mock_key_2.last_used_at = None + mock_key_2.created_at = None with ( app.test_request_context("/"), @@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi: ): response = method(api) - assert "items" in response - assert response["items"] == [mock_key_1, mock_key_2] + assert "data" in response + assert len(response["data"]) == 2 + assert response["data"][0]["id"] == "key-1" + assert response["data"][0]["token"] == "ds-abc" + assert response["data"][1]["id"] == "key-2" + assert response["data"][1]["token"] == "ds-def" def test_post_create_api_key_success(self, app): api = DatasetApiKeyApi() method = unwrap(api.post) + mock_token = MagicMock() + mock_token.id = "new-key-id" + mock_token.last_used_at = None + mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + + mock_api_token_cls = MagicMock() + mock_api_token_cls.return_value = mock_token + mock_api_token_cls.generate_api_key.return_value = "dataset-abc123" + with ( app.test_request_context("/"), patch( @@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi: return_value=3, ), patch( - "controllers.console.datasets.datasets.ApiToken.generate_api_key", - return_value="dataset-abc123", + "controllers.console.datasets.datasets.ApiToken", + mock_api_token_cls, ), patch( "controllers.console.datasets.datasets.db.session.add", @@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi: response, status = method(api) assert status == 200 - assert isinstance(response, ApiToken) - assert response.token == "dataset-abc123" - assert response.type == "dataset" + assert isinstance(response, dict) + assert response["id"] == "new-key-id" + assert response["token"] == "dataset-abc123" + assert response["type"] == "dataset" + assert response["created_at"] is not None def test_post_exceed_max_keys(self, app): api = DatasetApiKeyApi() @@ -1747,6 +1772,21 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" + def test_get_api_base_url_no_double_v1(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com/v1", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + class TestDatasetRetrievalSettingApi: def test_get_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f8..d9b02ac4533 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -215,17 +216,23 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 161d0c41e83..514bbbe0402 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,3 +1,4 @@ +from importlib import import_module from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -11,6 +12,7 @@ from controllers.console.datasets.external import ( BedrockRetrievalApi, ExternalApiTemplateApi, ExternalApiTemplateListApi, + ExternalApiUseCheckApi, ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) @@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService +external_controller = import_module("controllers.console.datasets.external") + def unwrap(func): while hasattr(func, "__wrapped__"): @@ -44,10 +48,11 @@ def current_user(): @pytest.fixture(autouse=True) -def mock_auth(mocker, current_user): - mocker.patch( - "controllers.console.datasets.external.current_account_with_tenant", - return_value=(current_user, "tenant-1"), +def mock_auth(monkeypatch, current_user): + monkeypatch.setattr( + external_controller, + "current_account_with_tenant", + lambda: (current_user, "tenant-1"), ) @@ -136,6 +141,26 @@ class TestExternalApiTemplateApi: method(api, "api-id") +class TestExternalApiUseCheckApi: + def test_get_scopes_usage_check_to_current_tenant(self, app): + api = ExternalApiUseCheckApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "external_knowledge_api_use_check", + return_value=(True, 2), + ) as mock_use_check, + ): + response, status = method(api, "api-id") + + assert status == 200 + assert response == {"is_using": True, "count": 2} + mock_use_check.assert_called_once_with("api-id", "tenant-1") + + class TestExternalDatasetCreateApi: def test_create_success(self, app): api = ExternalDatasetCreateApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf39..09ed2aaf697 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 710c9be684c..e4acd91b76c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -21,6 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 66c9ba48c59..b4b57022e28 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,7 +2,6 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -20,6 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 2e4ca4f2a44..145cc9cdd7e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -22,6 +21,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 02c7507ea72..76c863577a9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import controllers.console.explore.recommended_app as module +from models.model import AppMode, IconType def unwrap(func): @@ -90,3 +91,48 @@ class TestRecommendedAppApi: service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") assert result == result_data + + +class TestRecommendedAppResponseModels: + def test_recommended_app_info_response_computes_icon_url(self): + with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"): + payload = module.RecommendedAppInfoResponse.model_validate( + { + "id": "app-1", + "name": "App", + "mode": AppMode.CHAT, + "icon": "icon.png", + "icon_type": IconType.IMAGE, + "icon_background": "#fff", + } + ).model_dump(mode="json") + + assert payload["icon_url"] == "https://signed/icon.png" + + def test_recommended_app_list_response_serialization(self): + response = module.RecommendedAppListResponse.model_validate( + { + "recommended_apps": [ + { + "app": { + "id": "app-1", + "name": "App", + "mode": "chat", + "icon": "icon.png", + "icon_type": "emoji", + "icon_background": "#fff", + }, + "app_id": "app-1", + "description": "desc", + "category": "cat", + "position": 1, + "is_listed": True, + "can_trial": False, + } + ], + "categories": ["cat"], + } + ).model_dump(mode="json") + + assert response["recommended_apps"][0]["app_id"] == "app-1" + assert response["categories"] == ["cat"] diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 04beb31389c..3625056af98 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -26,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode @@ -94,7 +94,7 @@ class TestTrialAppWorkflowRunApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(MagicMock(mode=AppMode.CHAT)) + method(api, MagicMock(mode=AppMode.CHAT)) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -106,7 +106,7 @@ class TestTrialAppWorkflowRunApi: patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(trial_app_workflow) + result = method(api, trial_app_workflow) assert result is not None @@ -124,7 +124,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -140,7 +140,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_model_not_support(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -156,7 +156,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_invoke_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -172,7 +172,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(CompletionRequestError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -188,7 +188,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_value_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -204,7 +204,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ValueError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_generic_exception(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -220,7 +220,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InternalServerError): - method(trial_app_workflow) + method(api, trial_app_workflow) class TestTrialChatApi: @@ -566,7 +566,7 @@ class TestTrialMessageSuggestedQuestionApi: with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(api, MagicMock(mode="completion"), str(uuid4())) + method(MagicMock(mode="completion"), str(uuid4())) def test_success(self, app, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() @@ -581,7 +581,7 @@ class TestTrialMessageSuggestedQuestionApi: return_value=["q1", "q2"], ), ): - result = method(api, trial_app_chat, str(uuid4())) + result = method(trial_app_chat, str(uuid4())) assert result == {"data": ["q1", "q2"]} @@ -599,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(api, trial_app_chat, str(uuid4())) + method(trial_app_chat, str(uuid4())) class TestTrialAppParameterApi: @@ -931,7 +931,7 @@ class TestTrialAppWorkflowTaskStopApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(trial_app_chat, str(uuid4())) + method(api, trial_app_chat, str(uuid4())) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowTaskStopApi() @@ -944,7 +944,7 @@ class TestTrialAppWorkflowTaskStopApi: patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, ): - result = method(trial_app_workflow, task_id) + result = method(api, trial_app_workflow, task_id) assert result == {"result": "success"} mock_set_flag.assert_called_once_with(task_id) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index e89b89c8b11..2be5a21f289 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -1,9 +1,11 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import Forbidden +import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( TagBindingCreateApi, @@ -83,13 +85,20 @@ class TestTagListApi: ), patch( "controllers.console.tag.tags.TagService.get_tags", - return_value=[{"id": "1", "name": "tag"}], + return_value=[ + SimpleNamespace( + id="1", + name="tag", + type=TagType.KNOWLEDGE, + binding_count=1, + ) + ], ), ): result, status = method(api) assert status == 200 - assert isinstance(result, list) + assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] def test_post_success(self, app, admin_user, tag, payload_patch): api = TagListApi() @@ -113,6 +122,7 @@ class TestTagListApi: assert status == 200 assert result["name"] == "test-tag" + assert result["binding_count"] == "0" def test_post_forbidden(self, app, readonly_user, payload_patch): api = TagListApi() @@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 - assert result["binding_count"] == 3 + assert result["binding_count"] == "3" def test_patch_forbidden(self, app, readonly_user, payload_patch): api = TagUpdateDeleteApi() @@ -277,3 +287,13 @@ class TestTagBindingDeleteApi: ): with pytest.raises(Forbidden): method(api) + + +class TestTagResponseModel: + def test_tag_response_normalizes_enum_type(self): + payload = module.TagResponse.model_validate( + {"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1} + ).model_dump(mode="json") + + assert payload["type"] == "knowledge" + assert payload["binding_count"] == "1" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 9afc1c41661..26ff264f18c 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -11,7 +11,7 @@ from controllers.console.workspace.account import ( ChangeEmailSendEmailApi, CheckEmailUnique, ) -from models import Account +from models import Account, AccountStatus from services.account_service import AccountService @@ -20,7 +20,7 @@ def app(): app = Flask(__name__) app.config["TESTING"] = True app.config["RESTX_MASK_HEADER"] = "X-Fields" - app.login_manager = SimpleNamespace(_load_user=lambda: None) + app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return app @@ -33,7 +33,7 @@ def _build_account(email: str, account_id: str = "acc", tenant: object | None = account = Account(name=account_id, email=email) account.email = email account.id = account_id - account.status = "active" + account.status = AccountStatus.ACTIVE account._current_tenant = tenant_obj return account @@ -68,7 +68,10 @@ class TestChangeEmailSend: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = {"email": "current@example.com"} + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + } mock_send_email.return_value = "token-abc" with app.test_request_context( @@ -85,12 +88,55 @@ class TestChangeEmailSend: email="new@example.com", old_email="current@example.com", language="en-US", - phase="new_email", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, ) mock_extract_ip.assert_called_once() mock_is_ip_limit.assert_called_once_with("127.0.0.1") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -122,7 +168,12 @@ class TestChangeEmailValidity: mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } mock_generate_token.return_value = (None, "new-token") with app.test_request_context( @@ -138,11 +189,169 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", code="1234", old_email="old@example.com", additional_data={} + "user@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + }, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_upgrade_new_phase_token_to_new_verified( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "new@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + } + mock_generate_token.return_value = (None, "new-verified-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} + mock_generate_token.assert_called_once_with( + "new@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + }, + ) + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_phase_is_unknown( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token whose phase marker is a string but not a known transition must be rejected.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_has_no_phase( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + class TestChangeEmailReset: @patch("controllers.console.wraps.db") @@ -175,7 +384,11 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_get_data.return_value = { + "email": "new@example.com", + "old_email": "OLD@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -194,6 +407,155 @@ class TestChangeEmailReset: mock_send_notify.assert_called_once_with(email="new@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_phase_is_not_new_verified( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + # Simulate a token straight out of step #1 (phase=old_email) — exactly + # the replay used in the advisory PoC. + mock_get_data.return_value = { + "email": "old@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-from-step1"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_email_differs_from_payload_new_email( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """A verified token for address A must not be replayed to change to address B.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = { + "email": "verified@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-verified"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + +class TestAccountServiceSendChangeEmailEmail: + """Service-level coverage for the phase-bound changes in `send_change_email_email`.""" + + def test_should_raise_value_error_for_invalid_phase(self): + with pytest.raises(ValueError, match="phase must be one of"): + AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + phase="old_email_verified", + ) + + @patch("services.account_service.send_change_mail_task") + @patch("services.account_service.AccountService.change_email_rate_limiter") + @patch("services.account_service.AccountService.generate_change_email_token") + def test_should_stamp_phase_into_generated_token( + self, + mock_generate_token, + mock_rate_limiter, + mock_mail_task, + ): + mock_rate_limiter.is_rate_limited.return_value = False + mock_generate_token.return_value = ("123456", "the-token") + + returned = AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + language="en-US", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + + assert returned == "the-token" + mock_generate_token.assert_called_once_with( + "user@example.com", + None, + old_email="user@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + }, + ) + mock_mail_task.delay.assert_called_once() + mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @@ -233,15 +595,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 368892b9228..239fec84308 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole def app(): flask_app = Flask(__name__) flask_app.config["TESTING"] = True - flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return flask_app diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 9c42ee9529a..b2f949c6e2b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,9 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from werkzeug.exceptions import Forbidden + from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index fb9eec98cb5..168479af1ea 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -14,6 +13,7 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index c829327bc7a..f0d32f81fb5 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,8 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -16,6 +14,8 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf3..e82a29f0454 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -435,6 +436,23 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: def test_switch_success(self, app): api = SwitchWorkspaceApi() diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 4a5f91cc5d5..71381e6a2b4 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -18,6 +18,7 @@ from controllers.inner_api.app.dsl import ( InnerAppDSLImportPayload, _get_active_account, ) +from models.account import AccountStatus from services.app_dsl_service import ImportStatus @@ -63,7 +64,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_active_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "active" + mock_account.status = AccountStatus.ACTIVE mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") @@ -74,7 +75,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "banned" + mock_account.status = AccountStatus.BANNED mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -103,13 +104,15 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): - mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) - mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -145,6 +148,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -160,6 +165,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -175,6 +182,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 957d7fbd9be..d1b09c3a584 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -2,6 +2,7 @@ Unit tests for inner_api plugin decorators """ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -40,17 +41,22 @@ class TestTenantUserPayload: class TestGetUser: """Test get_user function""" + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -58,13 +64,45 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.get.assert_called_once() + mock_session.scalar.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_not_resolve_non_anonymous_users_across_tenants( + self, + mock_db, + mock_sessionmaker, + mock_enduser_class, + mock_select, + app: Flask, + ): + """Test that explicit user IDs remain scoped to the current tenant.""" + # Arrange + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + mock_new_user = MagicMock() + mock_new_user.tenant_id = "tenant-current" + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant-current", "foreign-user-id") + + # Assert + assert result == mock_new_user + mock_session.get.assert_not_called() + mock_session.scalar.assert_called_once() + mock_session.add.assert_called_once_with(mock_new_user) + + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange @@ -72,8 +110,9 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - # non-anonymous path uses session.get(); anonymous uses session.scalar() - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -82,17 +121,22 @@ class TestGetUser: # Assert assert result == mock_user + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = None + mock_session.scalar.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -133,7 +177,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.side_effect = Exception("Database error") + mock_session.scalar.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -232,11 +276,11 @@ class TestGetUserTenant: class PluginTestPayload: """Simple test payload class""" - def __init__(self, data: dict): + def __init__(self, data: dict[str, Any]): self.value = data.get("value") @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): return cls(data) @@ -277,7 +321,7 @@ class TestPluginData: # Arrange class InvalidPayload: @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): raise Exception("Validation failed") @plugin_data(payload_type=InvalidPayload) diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 56a8f949631..7d2193adc69 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -20,6 +20,7 @@ from controllers.inner_api.workspace.workspace import ( WorkspaceCreatePayload, WorkspaceOwnerlessPayload, ) +from models.account import TenantStatus class TestWorkspaceCreatePayload: @@ -98,7 +99,7 @@ class TestEnterpriseWorkspace: mock_tenant.id = "tenant-id" mock_tenant.name = "My Workspace" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.created_at = now mock_tenant.updated_at = now mock_tenant_svc.create_tenant.return_value = mock_tenant @@ -162,7 +163,7 @@ class TestEnterpriseWorkspaceNoOwnerEmail: mock_tenant.name = "My Workspace" mock_tenant.encrypt_public_key = "pub-key" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.custom_config = None mock_tenant.created_at = now mock_tenant.updated_at = now diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index 1507bf7a5fc..f48ace427d8 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -10,6 +10,7 @@ from flask import Flask from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi from controllers.service_api.app.error import AppUnavailableError +from models.account import TenantStatus from models.model import App, AppMode from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -62,7 +63,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL # Mock DB queries for app and tenant mock_db.session.get.side_effect = [ @@ -110,7 +111,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -151,7 +152,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -190,7 +191,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -253,7 +254,7 @@ class TestAppMetaApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -321,7 +322,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -378,7 +379,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -424,7 +425,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -476,7 +477,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 5a8cb4619f4..c16ebad7399 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,7 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -30,6 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( @@ -95,30 +95,6 @@ class TestTextToAudioPayload: assert payload.streaming is True -# --------------------------------------------------------------------------- -# AudioService Interface Tests -# --------------------------------------------------------------------------- - - -class TestAudioServiceInterface: - """Test AudioService method interfaces exist.""" - - def test_transcript_asr_method_exists(self): - """Test that AudioService.transcript_asr exists.""" - assert hasattr(AudioService, "transcript_asr") - assert callable(AudioService.transcript_asr) - - def test_transcript_tts_method_exists(self): - """Test that AudioService.transcript_tts exists.""" - assert hasattr(AudioService, "transcript_tts") - assert callable(AudioService.transcript_tts) - - -# --------------------------------------------------------------------------- -# Audio Service Tests -# --------------------------------------------------------------------------- - - class TestAudioServiceInterface: """Test suite for AudioService interface methods.""" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 57681d8f5bc..3364c07e623 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,7 +16,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -35,6 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index dbd06677d8a..4fb8ecf7844 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -15,6 +15,7 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch @@ -29,11 +30,16 @@ from controllers.service_api.app.conversation import ( ConversationRenameApi, ConversationRenamePayload, ConversationVariableDetailApi, + ConversationVariableInfiniteScrollPaginationResponse, + ConversationVariableResponse, ConversationVariablesApi, ConversationVariablesQuery, ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment +from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -261,6 +267,72 @@ class TestConversationVariableUpdatePayload: assert payload.value == nested +class TestConversationVariableResponseModels: + def test_variable_response_normalizes_value_type_and_timestamps(self): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "description": "desc", + "created_at": created_at, + "updated_at": created_at, + } + ) + assert response.value_type == "number" + assert response.value == "1" + assert response.created_at == int(created_at.timestamp()) + assert response.updated_at == int(created_at.timestamp()) + + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + + def test_variable_pagination_response(self): + response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( + { + "limit": 1, + "has_more": False, + "data": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": "string", + "value": "bar", + } + ], + } + ) + assert response.limit == 1 + assert response.has_more is False + assert len(response.data) == 1 + assert response.data[0].name == "foo" + + class TestConversationAppModeValidation: """Test app mode validation for conversation endpoints.""" @@ -549,6 +621,44 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "get_conversational_variable", + lambda *_args, **_kwargs: SimpleNamespace( + limit=1, + has_more=False, + data=[ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + } + ], + ), + ) + + api = ConversationVariablesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables?limit=20", + method="GET", + ): + result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + + assert result["limit"] == 1 + assert result["has_more"] is False + assert result["data"][0]["value_type"] == "number" + assert result["data"][0]["value"] == "1" + assert result["data"][0]["created_at"] == int(created_at.timestamp()) + class TestConversationVariableDetailApiController: def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -602,3 +712,41 @@ class TestConversationVariableDetailApiController: c_id="00000000-0000-0000-0000-000000000001", variable_id="00000000-0000-0000-0000-000000000002", ) + + def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "update_conversation_variable", + lambda *_args, **_kwargs: { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + }, + ) + + api = ConversationVariableDetailApi() + handler = _unwrap(api.put) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables/2", + method="PUT", + json={"value": 1}, + ): + result = handler( + api, + app_model=app_model, + end_user=end_user, + c_id="00000000-0000-0000-0000-000000000001", + variable_id="00000000-0000-0000-0000-000000000002", + ) + + assert result["id"] == "550e8400-e29b-41d4-a716-446655440000" + assert result["value_type"] == "number" + assert result["value"] == "1" + assert result["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index cfa21bf2dd6..da09ec13cea 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -15,11 +15,11 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -36,6 +36,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService +def _make_mock_workflow_run(run_id: str = "run-1"): + run = Mock() + run.id = run_id + run.workflow_id = "wf-1" + run.status = WorkflowExecutionStatus.SUCCEEDED + run.inputs = {"input": "value"} + run.outputs_dict = {"output": "value"} + run.error = None + run.total_steps = 1 + run.total_tokens = 10 + run.created_at = datetime(2026, 1, 1, tzinfo=UTC) + run.finished_at = datetime(2026, 1, 1, tzinfo=UTC) + run.elapsed_time = 0.1 + return run + + class TestWorkflowRunPayload: """Test suite for WorkflowRunPayload Pydantic model.""" @@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi: handler(api, app_model=app_model, workflow_run_id="run") def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - run = SimpleNamespace(id="run") + run = _make_mock_workflow_run(run_id="run") repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) @@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi: handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") - assert handler(api, app_model=app_model, workflow_run_id="run") == run + result = handler(api, app_model=app_model, workflow_run_id="run") + assert result["id"] == "run" + assert result["workflow_id"] == "wf-1" + assert result["status"] == "succeeded" class TestWorkflowRunApi: @@ -490,7 +510,7 @@ class TestWorkflowAppLogApi: monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", - lambda *_args, **_kwargs: {"items": [], "total": 0}, + lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}, ) api = WorkflowAppLogApi() @@ -500,7 +520,7 @@ class TestWorkflowAppLogApi: with app.test_request_context("/workflows/logs", method="GET"): response = handler(api, app_model=app_model) - assert response == {"items": [], "total": 0} + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} # ============================================================================= @@ -527,9 +547,8 @@ def mock_workflow_app(): class TestWorkflowRunDetailApiGet: """Test suite for WorkflowRunDetailApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) - and ``@service_api_ns.marshal_with``. We call the unwrapped method - directly; ``marshal_with`` is a no-op when calling directly. + ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``), + and we call the unwrapped method directly in tests. """ @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet: mock_workflow_app, ): """Test successful workflow run detail retrieval.""" - mock_run = Mock() - mock_run.id = "run-1" - mock_run.status = "succeeded" + mock_run = _make_mock_workflow_run(run_id="run-1") mock_repo = Mock() mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo @@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet: api = WorkflowRunDetailApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) - assert result == mock_run + assert result["id"] == mock_run.id + assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") def test_get_workflow_run_wrong_app_mode(self, mock_db, app): @@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost: class TestWorkflowAppLogApiGet: """Test suite for WorkflowAppLogApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` and - ``@service_api_ns.marshal_with``. + ``get`` is wrapped by ``@validate_app_token``. """ @patch("controllers.service_api.app.workflow.WorkflowAppService") @@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet: ): """Test successful workflow log retrieval.""" mock_pagination = Mock() + mock_pagination.page = 1 + mock_pagination.limit = 20 + mock_pagination.total = 0 + mock_pagination.has_more = False mock_pagination.data = [] mock_svc_instance = Mock() mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination @@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet: api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) - assert result == mock_pagination + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 4b8e3a738cb..eda270258d5 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,7 @@ from types import SimpleNamespace -from graphon.enums import WorkflowExecutionStatus - from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 12d5e7345d2..288659b192d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -22,6 +22,8 @@ import pytest from werkzeug.exceptions import Forbidden, NotFound from controllers.service_api.dataset.document import ( + DeprecatedDocumentAddByTextApi, + DeprecatedDocumentUpdateByTextApi, DocumentAddByFileApi, DocumentAddByTextApi, DocumentApi, @@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi: # Act with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={ "name": "Test Document", @@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1093,6 +1095,20 @@ class TestArchivedDocumentImmutableError: assert error.code == 403 +class TestDocumentTextRouteDeprecation: + """Test that legacy underscore text routes stay marked deprecated.""" + + def test_create_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy create-by-text alias is marked deprecated.""" + assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True + + def test_update_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy update-by-text alias is marked deprecated.""" + assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True + + # ============================================================================= # Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi, # DocumentUpdateByFileApi. @@ -1162,7 +1178,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Updated Doc", "text": "New content"}, headers={"Authorization": "Bearer test_token"}, @@ -1195,7 +1211,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Doc", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py deleted file mode 100644 index c0b40d070a5..00000000000 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Unit tests for Service API Site controller -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.service_api.app.site import AppSiteApi -from models.account import TenantStatus -from models.model import App, Site -from tests.unit_tests.conftest import setup_mock_tenant_account_query - - -class TestAppSiteApi: - """Test suite for AppSiteApi""" - - @pytest.fixture - def mock_app_model(self): - """Create a mock App model with tenant.""" - app = Mock(spec=App) - app.id = str(uuid.uuid4()) - app.tenant_id = str(uuid.uuid4()) - app.status = "normal" - app.enable_api = True - - mock_tenant = Mock() - mock_tenant.id = app.tenant_id - mock_tenant.status = TenantStatus.NORMAL - app.tenant = mock_tenant - - return app - - @pytest.fixture - def mock_site(self): - """Create a mock Site model.""" - site = Mock(spec=Site) - site.id = str(uuid.uuid4()) - site.app_id = str(uuid.uuid4()) - site.title = "Test Site" - site.icon = "icon-url" - site.icon_background = "#ffffff" - site.description = "Site description" - site.copyright = "Copyright 2024" - site.privacy_policy = "Privacy policy text" - site.custom_disclaimer = "Custom disclaimer" - site.default_language = "en-US" - site.prompt_public = True - site.show_workflow_steps = True - site.use_icon_as_answer_icon = False - site.chat_color_theme = "light" - site.chat_color_theme_inverted = False - site.icon_type = "image" - site.created_at = "2024-01-01T00:00:00" - site.updated_at = "2024-01-01T00:00:00" - return site - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_success( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test successful retrieval of site configuration.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - # Mock wraps.db for authentication - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site.db for site query - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - response = api.get() - - # Assert - assert response["title"] == "Test Site" - assert response["icon"] == "icon-url" - assert response["description"] == "Site description" - mock_db.session.scalar.assert_called_once() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_not_found( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - ): - """Test that Forbidden is raised when site is not found.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query to return None - mock_db.session.scalar.return_value = None - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_tenant_archived( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test that Forbidden is raised when tenant is archived.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query - mock_db.session.scalar.return_value = mock_site - - # Set tenant status to archived AFTER authentication - mock_app_model.tenant.status = TenantStatus.ARCHIVE - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_queries_by_app_id( - self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model - ): - """Test that site is queried using the app model's id.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - mock_site = Mock(spec=Site) - mock_site.id = str(uuid.uuid4()) - mock_site.app_id = mock_app_model.id - mock_site.title = "Test Site" - mock_site.icon = "icon-url" - mock_site.icon_background = "#ffffff" - mock_site.description = "Site description" - mock_site.copyright = "Copyright 2024" - mock_site.privacy_policy = "Privacy policy text" - mock_site.custom_disclaimer = "Custom disclaimer" - mock_site.default_language = "en-US" - mock_site.prompt_public = True - mock_site.show_workflow_steps = True - mock_site.use_icon_as_answer_icon = False - mock_site.chat_color_theme = "light" - mock_site.chat_color_theme_inverted = False - mock_site.icon_type = "image" - mock_site.created_at = "2024-01-01T00:00:00" - mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - api.get() - - # Assert - # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index cbfc8fa6130..a6ca441801b 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -22,6 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index 49039d03fe1..4f8d848637d 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -19,6 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py index 89ab93d8d4d..da88b109a81 100644 --- a/api/tests/unit_tests/controllers/web/test_message_endpoints.py +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi: with pytest.raises(NotChatAppError): MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - def test_wrong_mode_raises(self, app: Flask) -> None: - msg_id = uuid4() - with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): - with pytest.raises(NotChatAppError): - MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: msg_id = uuid4() diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py index dcf81337123..bceb65b89f6 100644 --- a/api/tests/unit_tests/controllers/web/test_pydantic_models.py +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -198,7 +198,7 @@ class TestMessageListQuery: assert q.limit == 20 def test_invalid_conversation_id(self) -> None: - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id="bad") def test_limit_bounds(self) -> None: @@ -216,7 +216,7 @@ class TestMessageListQuery: def test_invalid_first_id(self) -> None: cid = str(uuid4()) - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id=cid, first_id="invalid") diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index 0661c02578c..a01587d64a1 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -4,9 +4,12 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from jwt import InvalidTokenError +from werkzeug.exceptions import Unauthorized import services.errors.account from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi +from services.entities.auth_entities import LoginFailureReason def encode_code(code: str) -> str: @@ -115,13 +118,18 @@ class TestLoginApi: def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.error import AccountBannedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with pytest.raises(AccountBannedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch( "controllers.web.login.WebAppAuthService.authenticate", @@ -130,13 +138,87 @@ class TestLoginApi: def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.auth.error import AuthenticationFailedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with pytest.raises(AuthenticationFailedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountNotFoundError(), + ) + def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "missing@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "missing@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND + + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None) + def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None: + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "user@example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(InvalidTokenError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN + + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch( + "controllers.web.login.WebAppAuthService.get_user_through_email", + side_effect=Unauthorized("Account is banned."), + ) + @patch( + "controllers.web.login.WebAppAuthService.get_email_code_login_data", + return_value={"email": "User@Example.com", "code": "123456"}, + ) + def test_email_code_login_logs_banned_account( + self, + mock_get_token_data: MagicMock, + mock_get_user: MagicMock, + mock_revoke_token: MagicMock, + app: Flask, + ) -> None: + from controllers.console.error import AccountBannedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED class TestLoginStatusApi: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index bc7aea0ef92..cde8820e006 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index 97206019b9f..ea8cc8aa864 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index defc8b4b642..2f5873d865f 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,6 +1,8 @@ import json import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -8,8 +10,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner - # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index a44a0650eb4..17ab5babcbd 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,6 +3,11 @@ from typing import Any from unittest.mock import MagicMock import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, @@ -11,11 +16,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent - # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index 5ee66da94ab..186b4a501d3 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -12,6 +10,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index e2f3c16335f..d9fe7004ff7 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) +from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 8bde9c1f979..11b53dd0f97 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,8 +1,7 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager - def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index 000f83cd5a2..f2bc3076dac 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 061719d15a5..45d4b0e3212 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { @@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Note: Since we're mocking ConversationVariable.from_variable, # we can't directly check the id, but we can verify add_all was called assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_no_variables_creates_all(self): """Test that all conversation variables are created when none exist in DB.""" @@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that all variables were created assert len(added_items) == 2, "Should have added both variables" assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_all_variables_exist_no_changes(self): """Test that no changes are made when all variables already exist in DB.""" @@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that no variables were added assert not mock_session.add_all.called, "Session add_all should not have been called" - assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 079df0b4e61..5d8faee8971 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner): scalar=lambda *a, **k: MagicMock(), ), ), + sessionmaker=MagicMock( + return_value=MagicMock( + begin=MagicMock( + return_value=MagicMock( + __enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))), + __exit__=lambda *a, **k: False, + ), + ), + ), + ), select=MagicMock(), db=MagicMock(engine=MagicMock()), RedisChannel=MagicMock(), diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index e9fdeefee4d..f2df35d7d02 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index a6d85989556..99a386cd45e 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,8 +6,6 @@ from types import SimpleNamespace from unittest import mock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 82b2e51019f..29fd63c063a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,8 +4,6 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -49,6 +47,8 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 7dc43581501..80f7f94b1ac 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 08250bc3b6f..4567b354804 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 68bcffb0e83..8f3c41701b3 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,7 +2,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -10,6 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f255d2c7df2..b3ea1a464f8 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 4a94a2b4f1b..201923e0e40 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12e..dd6cd0e919d 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,10 +1,9 @@ from collections.abc import Mapping, Sequence +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter - class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" @@ -12,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index bc11bf41744..1bef6f69cdd 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,13 +1,12 @@ from datetime import UTC, datetime from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index c9e146ff126..936ac37e55a 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,11 +1,10 @@ from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 0fde7565d24..b3c0eb74faa 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,8 +10,6 @@ from typing import Any from unittest.mock import Mock import pytest -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -27,6 +25,8 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 619d66085a4..aa2085177e1 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 96af9fbdeef..f2e35f9900b 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 6cdcab29abe..cfe797aa764 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 4fe82efcb33..9db83f5531e 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -14,6 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index c8ae288e6fe..618c8fd76f3 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 6167be3bbdb..b0f8b423e1e 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,9 +476,8 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from graphon.enums import BuiltinNodeTypes - from core.app.entities.app_invoke_entities import InvokeFrom + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 1dee7fdab66..17de39ca99f 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,15 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -23,6 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a126bc85f75..6104b8d6ca2 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,9 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -25,12 +30,6 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -56,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -90,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index de5bca161c7..58c7bfa4bc2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,6 +4,23 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -24,23 +41,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables - class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index aa789d9ff3a..10fb2271f49 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 9e30faecf23..620a153204d 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 8a717e1dccd..a3ab379b66d 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,11 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -16,6 +11,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index b768e813bd7..7dd7ffd7277 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -11,6 +9,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 29df903aa8b..1f6e7e12ef1 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,15 +2,14 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState - from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index dabd2594b43..99433478d3e 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -2,10 +2,9 @@ from __future__ import annotations from contextlib import contextmanager from types import SimpleNamespace +from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -46,6 +45,8 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser @@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline: def test_database_session_rolls_back_on_error(self, monkeypatch): pipeline = _make_pipeline() - calls = {"commit": 0, "rollback": 0} - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs + calls = {"enter": 0, "exit_exc": None} + class _BeginContext: def __enter__(self): - return self + calls["enter"] += 1 + return MagicMock() def __exit__(self, exc_type, exc, tb): + calls["exit_exc"] = exc_type return False - def commit(self): - calls["commit"] += 1 + class _Sessionmaker: + def __init__(self, *args, **kwargs): + pass - def rollback(self): - calls["rollback"] += 1 + def begin(self): + return _BeginContext() - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker) monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) with pytest.raises(RuntimeError, match="db error"): with pipeline._database_session(): raise RuntimeError("db error") - assert calls["commit"] == 0 - assert calls["rollback"] == 1 + assert calls["enter"] == 1 + assert calls["exit_exc"] is RuntimeError def test_node_retry_and_started_handlers_cover_none_and_value(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 014a0cba729..7c797806411 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,11 +1,10 @@ -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index a78c1b428fb..ba55e8f6950 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from unittest.mock import Mock +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -8,10 +11,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment - -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035e64325bb..539944d683d 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,16 @@ from time import time from unittest.mock import Mock import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -15,16 +25,6 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 95931f4f8ba..12d49be0f1a 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,6 +1,5 @@ -from graphon.graph_events import GraphRunPausedEvent - from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 7cf6eb4f310..1ac9a4d8c0c 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,8 +1,7 @@ from unittest.mock import Mock, patch -from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand - from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index aa9285789b3..d3bd15b6f34 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,11 +2,10 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool - from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index 58aa7d74782..c246f7b7836 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 4aaa10a81af..1c1bf391d3e 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -28,6 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f7e7b7e20ef..a20d89d8077 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,9 +5,6 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -41,6 +38,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode @@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return agent_thought monkeypatch.setattr( @@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return None monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 31b73130666..595d716666c 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from graphon.file import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 07ee75ed35f..92fe3cbec67 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager -from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.entities import RetrievalSourceMetadata from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 29df7eea863..21c761c579b 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.model_entities import ModelPropertyKey - from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index dc2d82ccd6c..5c50cb78dae 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index 7be9d6ac1eb..701863b9272 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest -from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 8497261d453..30a068f4c5a 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,15 +1,15 @@ from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index a47d3db6f5b..82552470a95 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,8 @@ from __future__ import annotations from types import SimpleNamespace -from graphon.enums import BuiltinNodeTypes - from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerExtras: diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index d8a68f6d000..cacb4dd4fa9 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,6 +4,10 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause from graphon.enums import ( @@ -29,10 +33,6 @@ from graphon.graph_events import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables - class _RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 5ff9774b525..7b433ab57b9 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,6 +301,7 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -308,8 +309,6 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -337,11 +336,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index b0c72ee42f5..deeac49bbc2 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", @@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker): source_url="http://x", ) - class _Q: - def __init__(self, row): - self._row = row - - def where(self, *_args, **_kwargs): - return self - - def first(self): - return self._row - class _S: def __init__(self, row): self._row = row @@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q(self._row) + def scalar(self, *_args, **_kwargs): + return self._row mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row)) @@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker): def test_get_upload_file_by_id_raises_when_missing(mocker): - class _Q: - def where(self, *_args, **_kwargs): - return self - - def first(self): - return None - class _S: def __enter__(self): return self @@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q() + def scalar(self, *_args, **_kwargs): + return None mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index fbaf6d497d7..0fca43cd0b1 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ff9fd0d8f3f..ef8f360dbfe 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,12 +1,11 @@ -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType - from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 2acd278a31e..aeca2e3afdd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,9 +8,6 @@ drive provider mapping behavior. """ import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -19,6 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 8cf0409c4c2..a28143026ff 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,17 +6,6 @@ from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -35,6 +24,17 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index 8685d162831..a159d3ad4d0 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -9,6 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 216cda83c56..63e887f904e 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from core.extension.extensible import ExtensionModule @@ -12,10 +14,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -28,10 +30,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -43,10 +45,10 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -55,10 +57,10 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e2247..8cb09385755 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py new file mode 100644 index 00000000000..60002a757d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py @@ -0,0 +1,24 @@ +from pytest_mock import MockerFixture + +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter + + +def test_format_returns_result_value_as_string(mocker: MockerFixture) -> None: + execute_mock = mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={"result": 123}, + ) + + formatted = Jinja2Formatter.format("Hello {{ name }}", {"name": "Dify"}) + + assert formatted == "123" + execute_mock.assert_called_once() + + +def test_format_returns_empty_string_when_result_missing(mocker: MockerFixture) -> None: + mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={}, + ) + + assert Jinja2Formatter.format("Hello", {"name": "Dify"}) == "" diff --git a/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py new file mode 100644 index 00000000000..e09dd03489d --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.helper.code_executor import code_executor as code_executor_module + + +def test_execute_workflow_code_template_raises_for_unsupported_language() -> None: + with pytest.raises(code_executor_module.CodeExecutionError, match="Unsupported language"): + code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "ruby"), "print(1)", {}) + + +def test_execute_workflow_code_template_uses_transformer(mocker: MockerFixture) -> None: + transformer = MagicMock() + transformer.transform_caller.return_value = ("runner-script", "preload-script") + transformer.transform_response.return_value = {"result": "ok"} + execute_mock = mocker.patch.object( + code_executor_module.CodeExecutor, + "execute_code", + return_value='<
{qa.question}
{qa.answer}
{{ comment_content }}
+