diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
index 280fcb6341..69c099a262 100644
--- a/.agents/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -204,6 +204,16 @@ When assigned to test a directory/path, test **ALL content** within that path:
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
+### `nuqs` Query State Testing (Required for URL State Hooks)
+
+When a component or hook uses `useQueryState` / `useQueryStates`:
+
+- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`)
+- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`)
+- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values)
+- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable)
+- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test
+
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md
index 1ff2b27bbb..10b8fb66f9 100644
--- a/.agents/skills/frontend-testing/references/checklist.md
+++ b/.agents/skills/frontend-testing/references/checklist.md
@@ -80,6 +80,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
+- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`)
+- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`)
+- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values)
### Queries
diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md
index 86bd375987..f58377c4a5 100644
--- a/.agents/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -125,6 +125,31 @@ describe('Component', () => {
})
```
+### 2.1 `nuqs` Query State (Preferred: Testing Adapter)
+
+For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly.
+
+```typescript
+import { renderHookWithNuqs } from '@/test/nuqs-testing'
+
+it('should sync query to URL with push history', async () => {
+ const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), {
+ searchParams: '?page=1',
+ })
+
+ act(() => {
+ result.current.setQuery({ page: 2 })
+ })
+
+ await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
+ const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
+ expect(update.options.history).toBe('push')
+ expect(update.searchParams.get('page')).toBe('2')
+})
+```
+
+Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope.
+
### 3. Portal Components (with Shared State)
```typescript
diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md
index 4e3bfc7a37..b5cd62dfb5 100644
--- a/.agents/skills/orpc-contract-first/SKILL.md
+++ b/.agents/skills/orpc-contract-first/SKILL.md
@@ -1,43 +1,100 @@
---
name: orpc-contract-first
-description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
+description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
---
# oRPC Contract-First Development
-## Project Structure
+## Intent
-```
+- Keep contract as single source of truth in `web/contract/*`.
+- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
+- Keep abstractions minimal and preserve TypeScript inference.
+
+## Minimal Structure
+
+```text
web/contract/
-├── base.ts # Base contract (inputStructure: 'detailed')
-├── router.ts # Router composition & type exports
-├── marketplace.ts # Marketplace contracts
-└── console/ # Console contracts by domain
- ├── system.ts
- └── billing.ts
+├── base.ts
+├── router.ts
+├── marketplace.ts
+└── console/
+ ├── billing.ts
+ └── ...other domains
+web/service/client.ts
```
-## Workflow
+## Core Workflow
-1. **Create contract** in `web/contract/console/{domain}.ts`
- - Import `base` from `../base` and `type` from `@orpc/contract`
- - Define route with `path`, `method`, `input`, `output`
+1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
+ - Use `base.route({...}).output(type<...>())` as baseline.
+ - Add `.input(type<...>())` only when request has `params/query/body`.
+ - For `GET` without input, omit `.input(...)` (do not use `.input(type
Hello
", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert "".join(docs[0].page_content.split()) == "TitleHello" + + def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path): + file_path = tmp_path / "sample.html" + file_path.write_text(" \n ", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + + assert extractor._load_as_text() == "" diff --git a/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py new file mode 100644 index 0000000000..0b4c9bd809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py @@ -0,0 +1,47 @@ +from pytest_mock import MockerFixture + +from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor + + +class TestJinaReaderWebExtractor: + def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={ + "content": "markdown-content", + "url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "markdown-content" + assert docs[0].metadata == { + "source_url": "https://example.com", + "description": "desc", + "title": "title", + } + + def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value=None, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture): + mock_get_crawl = mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={"content": "unused"}, + ) + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert extractor.extract() == [] + mock_get_crawl.assert_not_called() diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py index d4cf534c56..7e78c86c7d 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py @@ -1,8 +1,15 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.markdown_extractor as markdown_module from core.rag.extractor.markdown_extractor import MarkdownExtractor -def test_markdown_to_tups(): - markdown = """ +class TestMarkdownExtractor: + def test_markdown_to_tups(self): + markdown = """ this is some text without header # title 1 @@ -11,12 +18,113 @@ this is balabala text ## title 2 this is more specific text. """ - extractor = MarkdownExtractor(file_path="dummy_path") - updated_output = extractor.markdown_to_tups(markdown) - assert len(updated_output) == 3 - key, header_value = updated_output[0] - assert key == None - assert header_value.strip() == "this is some text without header" - title_1, value = updated_output[1] - assert title_1.strip() == "title 1" - assert value.strip() == "this is balabala text" + extractor = MarkdownExtractor(file_path="dummy_path") + updated_output = extractor.markdown_to_tups(markdown) + + assert len(updated_output) == 3 + key, header_value = updated_output[0] + assert key is None + assert header_value.strip() == "this is some text without header" + + title_1, value = updated_output[1] + assert title_1.strip() == "title 1" + assert value.strip() == "this is balabala text" + + def test_markdown_to_tups_keeps_code_block_headers_literal(self): + markdown = """# Header +before +```python +# this is not a heading +print('x') +``` +after +""" + extractor = MarkdownExtractor(file_path="dummy_path") + + tups = extractor.markdown_to_tups(markdown) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "# this is not a heading" in tups[1][1] + + def test_remove_images_and_hyperlinks(self): + extractor = MarkdownExtractor(file_path="dummy_path") + + with_images = "before ![[image.png]] after" + with_links = "[OpenAI](https://openai.com)" + + assert extractor.remove_images(with_images) == "before after" + assert extractor.remove_hyperlinks(with_links) == "OpenAI" + + def test_parse_tups_reads_file_and_applies_options(self, tmp_path): + markdown_file = tmp_path / "doc.md" + markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8") + + extractor = MarkdownExtractor( + file_path=str(markdown_file), + remove_hyperlinks=True, + remove_images=True, + autodetect_encoding=False, + ) + + tups = extractor.parse_tups(str(markdown_file)) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "[link]" not in tups[1][1] + assert "img.png" not in tups[1][1] + + def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True) + + calls: list[str | None] = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + if encoding == "bad-encoding": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + return "# H\ncontent" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + markdown_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")], + ) + + tups = extractor.parse_tups("dummy_path") + + assert len(tups) == 2 + assert calls == [None, "bad-encoding", "utf-8"] + + def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False) + + def raise_decode(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + + monkeypatch.setattr(Path, "read_text", raise_decode, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + + def raise_other(self, encoding=None): + raise OSError("disk error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")]) + + docs = extractor.extract() + + assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"] diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index 58bec7d19e..6daee11f8f 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -1,93 +1,499 @@ +from types import SimpleNamespace from unittest import mock +import httpx +import pytest from pytest_mock import MockerFixture from core.rag.extractor import notion_extractor -user_id = "user1" -database_id = "database1" -page_id = "page1" - -extractor = notion_extractor.NotionExtractor( - notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" -) - - -def _generate_page(page_title: str): - return { - "object": "page", - "id": page_id, - "properties": { - "Page": { - "type": "title", - "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], - } - }, - } - - -def _generate_block(block_id: str, block_type: str, block_text: str): - return { - "object": "block", - "id": block_id, - "parent": {"type": "page_id", "page_id": page_id}, - "type": block_type, - "has_children": False, - block_type: { - "rich_text": [ - { - "type": "text", - "text": {"content": block_text}, - "plain_text": block_text, - } - ] - }, - } - - -def _mock_response(data): +def _mock_response(data, status_code: int = 200, text: str = ""): response = mock.Mock() - response.status_code = 200 + response.status_code = status_code + response.text = text response.json.return_value = data return response -def _remove_multiple_new_lines(text): - while "\n\n" in text: - text = text.replace("\n\n", "\n") - return text.strip() +class TestNotionExtractorInitAndPublicMethods: + def test_init_with_explicit_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor._notion_access_token == "token" + + def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False) + + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + assert extractor._notion_access_token == "env-token" + + def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False) + + with pytest.raises(ValueError, match="Must specify `integration_token`"): + notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + update_mock = mock.Mock() + load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")]) + monkeypatch.setattr(extractor, "update_last_edited_time", update_mock) + monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock) + + docs = extractor.extract() + + update_mock.assert_called_once_with(None) + load_mock.assert_called_once_with("obj", "page") + assert len(docs) == 1 + + def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"]) + page_docs = extractor._load_data_as_documents("page-id", "page") + assert page_docs[0].page_content == "line1\nline2" + + monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")]) + db_docs = extractor._load_data_as_documents("db-id", "database") + assert db_docs[0].page_content == "db" + + with pytest.raises(ValueError, match="notion page type not supported"): + extractor._load_data_as_documents("obj", "unsupported") -def test_notion_page(mocker: MockerFixture): - texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] - mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]), - ], - "next_cursor": None, - } - mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) +class TestNotionDatabase: + def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) - page_docs = extractor._load_data_as_documents(page_id, "page") - assert len(page_docs) == 1 - content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" + first_page = { + "results": [ + { + "properties": { + "tags": { + "type": "multi_select", + "multi_select": [{"name": "A"}, {"name": "B"}], + }, + "title_prop": {"type": "title", "title": [{"plain_text": "Title"}]}, + "empty_title": {"type": "title", "title": []}, + "rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]}, + "empty_rich": {"type": "rich_text", "rich_text": []}, + "select_prop": {"type": "select", "select": {"name": "Selected"}}, + "empty_select": {"type": "select", "select": None}, + "status_prop": {"type": "status", "status": {"name": "Open"}}, + "empty_status": {"type": "status", "status": None}, + "number_prop": {"type": "number", "number": 10}, + "dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": None}}, + }, + "url": "https://notion.so/page-1", + } + ], + "has_more": True, + "next_cursor": "cursor-2", + } + second_page = {"results": [], "has_more": False, "next_cursor": None} + + mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)]) + + docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}}) + + assert len(docs) == 1 + content = docs[0].page_content + assert "tags:['A', 'B']" in content + assert "title_prop:Title" in content + assert "rich:RichText" in content + assert "number_prop:10" in content + assert "dict_prop:start:2024-01-01" in content + assert "Row Page URL:https://notion.so/page-1" in content + assert mock_post.call_count == 2 + + def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.post", return_value=_mock_response({"results": None})) + assert extractor._get_notion_database_data("db-1") == [] + + def test_get_notion_database_data_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor._get_notion_database_data("db-1") -def test_notion_database(mocker: MockerFixture): - page_title_list = ["page1", "page2", "page3"] - mocked_notion_database = { - "object": "list", - "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None, - } - mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database)) - database_docs = extractor._load_data_as_documents(database_id, "database") - assert len(database_docs) == 1 - content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == "\n".join([f"Page:{i}" for i in page_title_list]) +class TestNotionBlocks: + def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + first_response = { + "results": [ + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + { + "type": "heading_1", + "id": "h1", + "has_children": False, + "heading_1": {"rich_text": [{"text": {"content": "Heading"}}]}, + }, + { + "type": "paragraph", + "id": "p1", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]}, + }, + { + "type": "child_page", + "id": "cp1", + "has_children": True, + "child_page": {"rich_text": []}, + }, + ], + "next_cursor": "cursor-2", + } + second_response = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(first_response), _mock_response(second_response)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE") + mocker.patch.object(extractor, "_read_block", return_value="CHILD") + + lines = extractor._get_notion_block_data("page-1") + + assert lines[0] == "TABLE\n\n" + assert "# Heading" in lines[1] + assert "Paragraph\nCHILD\n\n" in lines[2] + assert "## SubHeading" in lines[-1] + + def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.request", side_effect=httpx.HTTPError("network")) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200)) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom")) + with pytest.raises(ValueError, match="Error fetching Notion block data: boom"): + extractor._get_notion_block_data("page-1") + + def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + root_payload = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "Root"}}]}, + }, + { + "type": "paragraph", + "id": "child-block", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Parent"}}]}, + }, + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + ], + "next_cursor": None, + } + child_payload = { + "results": [ + { + "type": "paragraph", + "id": "leaf", + "has_children": False, + "paragraph": {"rich_text": [{"text": {"content": "Child"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD") + + content = extractor._read_block("root") + + assert "## Root" in content + assert "Parent" in content + assert "Child" in content + assert "TABLE-MD" in content + + def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None})) + + assert extractor._read_block("root") == "" + + def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + page_one = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [{"text": {"content": "H2"}}], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R1C1"}}], + [{"text": {"content": "R1C2"}}], + ] + } + }, + ], + "next_cursor": "next", + } + page_two = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R2C1"}}], + [{"text": {"content": "R2C2"}}], + ] + } + }, + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)]) + + markdown = extractor._read_table_rows("tbl-1") + + assert "| H1 | H2 |" in markdown + assert "| R1C1 | R1C2 |" in markdown + assert "| H1 | |" in markdown + assert "| R2C1 | R2C2 |" in markdown + + +class TestNotionMetadataAndCredentialMethods: + def test_update_last_edited_time_no_document_model(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor.update_last_edited_time(None) is None + + def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + class FakeDocumentModel: + data_source_info = "data_source_info" + + update_calls = [] + + class FakeQuery: + def filter_by(self, **kwargs): + return self + + def update(self, payload): + update_calls.append(payload) + + class FakeSession: + committed = False + + def query(self, model): + assert model is FakeDocumentModel + return FakeQuery() + + def commit(self): + self.committed = True + + fake_db = SimpleNamespace(session=FakeSession()) + monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) + monkeypatch.setattr(notion_extractor, "db", fake_db) + monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") + + doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) + extractor.update_last_edited_time(doc_model) + + assert update_calls + assert fake_db.session.committed is True + + def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): + extractor_page = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="page-id", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"}) + ) + + assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z" + assert "pages/page-id" in request_mock.call_args[0][1] + + extractor_db = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="db-id", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"}) + ) + + assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z" + assert "databases/db-id" in request_mock.call_args[0][1] + + def test_get_notion_last_edited_time_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor.get_notion_last_edited_time() + + def test_get_access_token_success_and_errors(self, monkeypatch): + with pytest.raises(Exception, match="No credential id found"): + notion_extractor.NotionExtractor._get_access_token("tenant", None) + + class FakeProviderServiceMissing: + def get_datasource_credentials(self, **kwargs): + return None + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing) + with pytest.raises(Exception, match="No notion credential found"): + notion_extractor.NotionExtractor._get_access_token("tenant", "cred") + + class FakeProviderServiceFound: + def get_datasource_credentials(self, **kwargs): + return {"integration_secret": "token-from-credential"} + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound) + + assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == "token-from-credential" diff --git a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py new file mode 100644 index 0000000000..fb3c6e52c6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.text_extractor as text_module +from core.rag.extractor.text_extractor import TextExtractor + + +class TestTextExtractor: + def test_extract_success(self, tmp_path): + file_path = tmp_path / "data.txt" + file_path.write_text("hello world", encoding="utf-8") + + extractor = TextExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "hello world" + assert docs[0].metadata == {"source": str(file_path)} + + def test_extract_autodetect_success_after_decode_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + calls = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + if encoding == "bad": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + return "decoded text" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + text_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")], + ) + + docs = extractor.extract() + + assert docs[0].page_content == "decoded text" + assert calls == [None, "bad", "utf-8"] + + def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")]) + + with pytest.raises(RuntimeError, match="all detected encodings failed"): + extractor.extract() + + def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=False) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + + with pytest.raises(RuntimeError, match="specified encoding failed"): + extractor.extract() + + def test_extract_wraps_non_decode_exceptions(self, monkeypatch): + extractor = TextExtractor("dummy.txt") + + def raise_other(self, encoding=None): + raise OSError("io error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy.txt"): + extractor.extract() diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 0792ada194..64eb89590a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -3,9 +3,12 @@ import io import os import tempfile +from collections import UserDict from pathlib import Path from types import SimpleNamespace +from unittest.mock import MagicMock +import pytest from docx import Document from docx.oxml import OxmlElement from docx.oxml.ns import qn @@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch): monkeypatch.setattr(we, "UploadFile", FakeUploadFile) # Patch external image fetcher - def fake_get(url: str): + def fake_get(url: str, **kwargs): assert url == "https://example.com/image.png" return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) @@ -203,10 +206,8 @@ def test_extract_images_from_docx_uses_internal_files_url(): finally: # Restore original values - if original_files_url is not None: - dify_config.FILES_URL = original_files_url - if original_internal_files_url is not None: - dify_config.INTERNAL_FILES_URL = original_internal_files_url + dify_config.FILES_URL = original_files_url + dify_config.INTERNAL_FILES_URL = original_internal_files_url def test_extract_hyperlinks(monkeypatch): @@ -314,3 +315,405 @@ def test_extract_legacy_hyperlinks(monkeypatch): finally: if os.path.exists(tmp_path): os.remove(tmp_path) + + +def test_init_rejects_invalid_url_status(monkeypatch): + class FakeResponse: + status_code = 404 + content = b"" + closed = False + + def close(self): + self.closed = True + + fake_response = FakeResponse() + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response)) + + with pytest.raises(ValueError, match="returned status code 404"): + WordExtractor("https://example.com/missing.docx", "tenant", "user") + + assert fake_response.closed is True + + +def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): + target_file = tmp_path / "expanded.docx" + target_file.write_bytes(b"docx") + + monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(target_file)) + monkeypatch.setattr( + we.os.path, + "isfile", + lambda p: p == str(target_file), + ) + + extractor = WordExtractor("~/expanded.docx", "tenant", "user") + assert extractor.file_path == str(target_file) + + monkeypatch.setattr(we.os.path, "isfile", lambda p: False) + with pytest.raises(ValueError, match="is not a valid file or url"): + WordExtractor("not-a-file", "tenant", "user") + + +def test_del_closes_temp_file(): + extractor = object.__new__(WordExtractor) + extractor.temp_file = MagicMock() + + WordExtractor.__del__(extractor) + + extractor.temp_file.close.assert_called_once() + + +def test_extract_images_handles_invalid_external_cases(monkeypatch): + class FakeTargetRef: + def __contains__(self, item): + return item == "image" + + def split(self, sep): + return [None] + + rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url") + rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error") + rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown") + rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object()) + + doc = SimpleNamespace( + part=SimpleNamespace( + rels={ + "r1": rel_invalid_url, + "r2": rel_request_error, + "r3": rel_unknown_mime, + "r4": rel_internal_none_ext, + } + ) + ) + + def fake_get(url, **kwargs): + if "image-error" in url: + raise RuntimeError("network") + return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x") + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock())) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None)) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "tenant" + extractor.user_id = "user" + + result = extractor._extract_images_from_docx(doc) + + assert result == {} + db_stub.session.commit.assert_called_once() + + +def test_table_to_markdown_and_parse_helpers(monkeypatch): + extractor = object.__new__(WordExtractor) + + table = SimpleNamespace( + rows=[ + SimpleNamespace(cells=[1, 2]), + SimpleNamespace(cells=[3, 4]), + ] + ) + parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]]) + monkeypatch.setattr(extractor, "_parse_row", parse_row_mock) + + markdown = extractor._table_to_markdown(table, {}) + assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |" + + class FakeBlip: + def __init__(self, image_id): + self.image_id = image_id + + def get(self, key): + return self.image_id + + class FakeRunChild: + def __init__(self, blips, text=""): + self._blips = blips + self.text = text + self.tag = qn("w:r") + + def xpath(self, pattern): + if pattern == ".//a:blip": + return self._blips + return [] + + class FakeRun: + def __init__(self, element, paragraph): + # Mirror the subset used by _parse_cell_paragraph + self.element = element + self.text = getattr(element, "text", "") + + # Patch we.Run so our lightweight child objects work with the extractor + monkeypatch.setattr(we, "Run", FakeRun) + + image_part = object() + paragraph = SimpleNamespace( + _element=[ + FakeRunChild([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")], text=""), + FakeRunChild([], text="plain"), + ], + part=SimpleNamespace( + rels={ + "ext": SimpleNamespace(is_external=True), + "int": SimpleNamespace(is_external=False, target_part=image_part), + } + ), + ) + + image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"} + assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain" + + cell = SimpleNamespace(paragraphs=[paragraph, paragraph]) + assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain" + + +def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch): + extractor = object.__new__(WordExtractor) + + ext_image_id = "ext-image" + int_embed_id = "int-embed" + shape_ext_id = "shape-ext" + shape_int_id = "shape-int" + + internal_part = object() + shape_internal_part = object() + + class Rels(UserDict): + def get(self, key, default=None): + if key == "link-bad": + raise RuntimeError("cannot resolve relation") + return super().get(key, default) + + rels = Rels( + { + ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"), + int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part), + shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"), + shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part), + "link-ok": SimpleNamespace(is_external=True, target_ref="https://example.com"), + } + ) + + image_map = { + ext_image_id: "[EXT]", + internal_part: "[INT]", + shape_ext_id: "[SHAPE_EXT]", + shape_internal_part: "[SHAPE_INT]", + } + + class FakeBlip: + def __init__(self, embed_id): + self.embed_id = embed_id + + def get(self, key): + return self.embed_id + + class FakeDrawing: + def __init__(self, embed_ids): + self.embed_ids = embed_ids + + def findall(self, pattern): + return [FakeBlip(embed_id) for embed_id in self.embed_ids] + + class FakeNode: + def __init__(self, text=None, attrs=None): + self.text = text + self._attrs = attrs or {} + + def get(self, key): + return self._attrs.get(key) + + class FakeShape: + def __init__(self, bin_id=None, img_id=None): + self.bin_id = bin_id + self.img_id = img_id + + def find(self, pattern): + if "binData" in pattern and self.bin_id: + return FakeNode( + text="shape", + attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id}, + ) + if "imagedata" in pattern and self.img_id: + return FakeNode(attrs={"id": self.img_id}) + return None + + class FakeChild: + def __init__( + self, + tag, + text="", + fld_chars=None, + instr_texts=None, + drawings=None, + shapes=None, + attrs=None, + hyperlink_runs=None, + ): + self.tag = tag + self.text = text + self._fld_chars = fld_chars or [] + self._instr_texts = instr_texts or [] + self._drawings = drawings or [] + self._shapes = shapes or [] + self._attrs = attrs or {} + self._hyperlink_runs = hyperlink_runs or [] + + def findall(self, pattern): + if pattern == qn("w:fldChar"): + return self._fld_chars + if pattern == qn("w:instrText"): + return self._instr_texts + if pattern == qn("w:r"): + return self._hyperlink_runs + if pattern.endswith("}drawing"): + return self._drawings + if pattern.endswith("}pict"): + return self._shapes + return [] + + def get(self, key): + return self._attrs.get(key) + + class FakeRun: + def __init__(self, element, paragraph): + self.element = element + self.text = getattr(element, "text", "") + + paragraph_main = SimpleNamespace( + _element=[ + FakeChild( + qn("w:r"), + text="run-text", + drawings=[FakeDrawing([ext_image_id, int_embed_id])], + shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)], + ), + FakeChild( + qn("w:r"), + text="", + drawings=[], + shapes=[FakeShape(bin_id=shape_ext_id)], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-ok"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-bad"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")], + ), + ] + ) + paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")]) + + fake_doc = SimpleNamespace( + part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}), + paragraphs=[paragraph_main, paragraph_empty], + tables=[SimpleNamespace(rows=[])], + element=SimpleNamespace( + body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")] + ), + ) + + monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc) + monkeypatch.setattr(we, "Run", FakeRun) + monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) + monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN") + logger_exception = MagicMock() + monkeypatch.setattr(we.logger, "exception", logger_exception) + + content = extractor.parse_docx("dummy.docx") + + assert "[EXT]" in content + assert "[INT]" in content + assert "[SHAPE_EXT]" in content + assert "[LinkText](https://example.com)" in content + assert "BrokenLink" in content + assert "TABLE-MARKDOWN" in content + logger_exception.assert_called_once() + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + # Build modern hyperlink inside table cell + r_id = "rIdHttp1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + # Relationship for external http link + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[Dify](https://dify.ai)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + r_id = "rIdMail1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "john@test.com" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "mailto:john@test.com", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[john@test.com](mailto:john@test.com)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py new file mode 100644 index 0000000000..26ce333e11 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py @@ -0,0 +1,300 @@ +"""Unit tests for unstructured extractors and their local/API partitioning paths.""" + +import base64 +import sys +import types +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor + + +def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType: + module = types.ModuleType(name) + for k, v in attrs.items(): + setattr(module, k, v) + monkeypatch.setitem(sys.modules, name, module) + return module + + +def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None: + _register_module(monkeypatch, "unstructured", __path__=[]) + _register_module(monkeypatch, "unstructured.partition", __path__=[]) + _register_module(monkeypatch, "unstructured.chunking", __path__=[]) + _register_module(monkeypatch, "unstructured.file_utils", __path__=[]) + + +def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None: + _register_unstructured_packages(monkeypatch) + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + return chunks + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + +class TestUnstructuredMarkdownMsgXml: + def test_markdown_extractor_without_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")]) + _register_module( + monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")] + ) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract() + + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_markdown_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="ignored")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "via-api" + assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"} + + def test_msg_extractor_local(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + _register_module( + monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")] + ) + + assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc" + + def test_msg_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content + == "msg-doc" + ) + assert calls["filename"] == "/tmp/file.msg" + + def test_xml_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")]) + + xml_calls = {} + + def partition_xml(filename, xml_keep_tags): + xml_calls.update({"filename": filename, "xml_keep_tags": xml_keep_tags}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml) + + assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc" + assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True} + + api_calls = {} + + def partition_via_api(filename, api_url, api_key): + api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content + == "xml-doc" + ) + assert api_calls["filename"] == "/tmp/file.xml" + + +class TestUnstructuredEmailAndEpub: + def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + captured = {} + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + captured["elements"] = list(elements) + return [SimpleNamespace(text=" chunked-email ")] + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + html = "Hello Email
" + encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8") + bad_base64 = "not-base64" + + elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)] + _register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements) + + docs = UnstructuredEmailExtractor("/tmp/file.eml").extract() + + assert docs[0].page_content == "chunked-email" + chunk_elements = captured["elements"] + assert "Hello Email" in chunk_elements[0].text + assert chunk_elements[1].text == bad_base64 + + def test_email_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")]) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")], + ) + + docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "api-email" + + def test_epub_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")]) + + calls = {"download": 0, "partition": 0} + + def fake_download_pandoc(): + calls["download"] += 1 + + def partition_epub(filename, xml_keep_tags): + calls["partition"] += 1 + assert xml_keep_tags is True + return [SimpleNamespace(text="x")] + + monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc) + _register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub) + + docs = UnstructuredEpubExtractor("/tmp/file.epub").extract() + + assert docs[0].page_content == "epub-doc" + assert calls == {"download": 1, "partition": 1} + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")], + ) + + docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract() + assert docs[0].page_content == "epub-doc" + + +class TestUnstructuredPPTAndPPTX: + def test_ppt_extractor_requires_api_url(self): + with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"): + UnstructuredPPTExtractor("/tmp/file.ppt").extract() + + def test_ppt_extractor_groups_text_by_page(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)), + SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)), + ], + ) + + docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract() + + assert [doc.page_content for doc in docs] == ["A\nB", "C"] + + def test_pptx_extractor_local_and_api(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.pptx", + partition_pptx=lambda filename: [ + SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)), + SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract() + assert [doc.page_content for doc in docs] == ["P1", "P2"] + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract() + assert [doc.page_content for doc in docs] == ["X\nY"] + + +class TestUnstructuredWord: + def _install_doc_modules(self, monkeypatch, version: str, filetype_value): + _register_unstructured_packages(monkeypatch) + + class FileType: + DOC = "doc" + + _register_module(monkeypatch, "unstructured.__version__", __version__=version) + _register_module( + monkeypatch, + "unstructured.file_utils.filetype", + FileType=FileType, + detect_filetype=lambda filename: filetype_value, + ) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")], + ) + _register_module( + monkeypatch, + "unstructured.partition.docx", + partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")], + ) + _register_module( + monkeypatch, + "unstructured.chunking.title", + chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [ + SimpleNamespace(text="chunk-1"), + SimpleNamespace(text="chunk-2"), + ], + ) + + def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc") + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + + def test_word_extractor_doc_and_docx_paths(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc") + + docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc") + docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used") + monkeypatch.setitem(sys.modules, "magic", None) + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() diff --git a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py new file mode 100644 index 0000000000..d758be218a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py @@ -0,0 +1,434 @@ +"""Unit tests for WaterCrawl client, provider, and extractor behavior.""" + +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import core.rag.extractor.watercrawl.client as client_module +from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) +from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor +from core.rag.extractor.watercrawl.provider import WaterCrawlProvider + + +def _response( + status_code: int, + json_data: dict[str, Any] | None = None, + content_type: str = "application/json", + content: bytes = b"", + text: str = "", +) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.headers = {"Content-Type": content_type} + response.content = content + response.text = text + response.json.return_value = json_data if json_data is not None else {} + response.raise_for_status.return_value = None + response.close.return_value = None + return response + + +class TestWaterCrawlExceptions: + def test_bad_request_error_properties_and_string(self): + response = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}}) + + err = WaterCrawlBadRequestError(response) + parsed_errors = json.loads(err.flat_errors) + + assert err.status_code == 400 + assert err.message == "bad request" + assert "url" in parsed_errors + assert any("invalid" in str(item) for item in parsed_errors["url"]) + assert "WaterCrawlBadRequestError" in str(err) + + def test_permission_and_authentication_error_strings(self): + response = _response(403, {"message": "quota exceeded", "errors": {}}) + + permission = WaterCrawlPermissionError(response) + authentication = WaterCrawlAuthenticationError(response) + + assert "exceeding your WaterCrawl API limits" in str(permission) + assert "API key is invalid or expired" in str(authentication) + + +class TestBaseAPIClient: + def test_init_session_builds_expected_headers(self, monkeypatch): + captured = {} + + def fake_client(**kwargs): + captured.update(kwargs) + return "session" + + monkeypatch.setattr(client_module.httpx, "Client", fake_client) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client.session == "session" + assert captured["headers"]["X-API-Key"] == "k" + assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin" + + def test_request_stream_and_non_stream_paths(self, monkeypatch): + class FakeSession: + def __init__(self): + self.request_calls = [] + self.build_calls = [] + self.send_calls = [] + + def request(self, method, url, params=None, json=None, **kwargs): + self.request_calls.append((method, url, params, json, kwargs)) + return "non-stream-response" + + def build_request(self, method, url, params=None, json=None): + req = (method, url, params, json) + self.build_calls.append(req) + return req + + def send(self, request, stream=False, **kwargs): + self.send_calls.append((request, stream, kwargs)) + return "stream-response" + + fake_session = FakeSession() + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: fake_session) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response" + assert fake_session.request_calls[0][1] == "https://watercrawl.dev/v1/items" + + assert client._request("GET", "/v1/items", stream=True) == "stream-response" + assert fake_session.build_calls + assert fake_session.send_calls[0][1] is True + + def test_http_method_helpers_delegate_to_request(self, monkeypatch): + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock()) + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + calls = [] + + def fake_request(method, endpoint, query_params=None, data=None, **kwargs): + calls.append((method, endpoint, query_params, data)) + return "ok" + + monkeypatch.setattr(client, "_request", fake_request) + + assert client._get("/a") == "ok" + assert client._post("/b", data={"x": 1}) == "ok" + assert client._put("/c", data={"x": 2}) == "ok" + assert client._delete("/d") == "ok" + assert client._patch("/e", data={"x": 3}) == "ok" + assert [c[0] for c in calls] == ["GET", "POST", "PUT", "DELETE", "PATCH"] + + +class TestWaterCrawlAPIClient: + def test_process_eventstream_and_download(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = MagicMock() + response.iter_lines.return_value = [ + b"event: keep-alive", + b'data: {"type":"result","data":{"result":"http://x"}}', + b'data: {"type":"log","data":{"msg":"ok"}}', + ] + + monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"}) + + events = list(client.process_eventstream(response, download=True)) + + assert events[0]["data"]["result"]["markdown"] == "body" + assert events[1]["type"] == "log" + response.close.assert_called_once() + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (401, WaterCrawlAuthenticationError), + (403, WaterCrawlPermissionError), + (422, WaterCrawlBadRequestError), + ], + ) + def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]): + client = WaterCrawlAPIClient(api_key="k") + + with pytest.raises(expected_exception): + client.process_response(_response(status, {"message": "bad", "errors": {"url": ["x"]}})) + + def test_process_response_204_returns_none(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(204, None)) is None + + def test_process_response_json_payloads(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(200, {"ok": True})) == {"ok": True} + assert client.process_response(_response(200, None)) == {} + + def test_process_response_octet_stream_returns_bytes(self): + client = WaterCrawlAPIClient(api_key="k") + assert ( + client.process_response(_response(200, content_type="application/octet-stream", content=b"bin")) == b"bin" + ) + + def test_process_response_event_stream_returns_generator(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + generator = (item for item in [{"type": "result", "data": {}}]) + monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: generator) + assert client.process_response(_response(200, content_type="text/event-stream")) is generator + + def test_process_response_unknown_content_type_raises(self): + client = WaterCrawlAPIClient(api_key="k") + with pytest.raises(Exception, match="Unknown response type"): + client.process_response(_response(200, content_type="text/plain", text="x")) + + def test_process_response_uses_raise_for_status(self): + client = WaterCrawlAPIClient(api_key="k") + response = _response(500, {"message": "server"}) + response.raise_for_status.side_effect = RuntimeError("http error") + + with pytest.raises(RuntimeError, match="http error"): + client.process_response(response) + + def test_endpoint_wrappers(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda resp: "processed") + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp") + monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp") + monkeypatch.setattr(client, "_delete", lambda *args, **kwargs: "delete-resp") + + assert client.get_crawl_requests_list() == "processed" + assert client.get_crawl_request("id") == "processed" + assert client.create_crawl_request(url="https://x") == "processed" + assert client.stop_crawl_request("id") == "processed" + assert client.download_crawl_request("id") == "processed" + assert client.get_crawl_request_results("id") == "processed" + + def test_monitor_crawl_request_generator_and_validation(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}])) + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp") + + events = list(client.monitor_crawl_request("job-1", prefetched=True)) + assert events == [{"type": "result", "data": 1}] + + monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}]) + with pytest.raises(ValueError, match="Generator expected"): + list(client.monitor_crawl_request("job-1")) + + def test_scrape_url_sync_and_async(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"}) + + async_result = client.scrape_url("https://example.com", sync=False) + assert async_result == {"uuid": "job-1"} + + monkeypatch.setattr( + client, + "monitor_crawl_request", + lambda item_id, prefetched: iter( + [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}] + ), + ) + sync_result = client.scrape_url("https://example.com", sync=True) + assert sync_result == {"url": "https://example.com"} + + def test_download_result_fetches_json_and_closes(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = _response(200, {"markdown": "body"}) + monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response) + + result = client.download_result({"result": "https://example.com/result.json"}) + + assert result["result"] == {"markdown": "body"} + response.close.assert_called_once() + + +class TestWaterCrawlProvider: + def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + captured_kwargs = {} + + def create_crawl_request_spy(**kwargs): + captured_kwargs.update(kwargs) + return {"uuid": "job-1"} + + monkeypatch.setattr(provider.client, "create_crawl_request", create_crawl_request_spy) + + result = provider.crawl_url( + "https://example.com", + { + "crawl_sub_pages": True, + "limit": 5, + "max_depth": 2, + "includes": "a,b", + "excludes": "x,y", + "exclude_tags": "nav,footer", + "include_tags": "main", + "wait_time": 100, + "only_main_content": False, + }, + ) + + assert result == {"status": "active", "job_id": "job-1"} + assert captured_kwargs["url"] == "https://example.com" + assert captured_kwargs["spider_options"] == { + "max_depth": 2, + "page_limit": 5, + "allowed_domains": [], + "exclude_paths": ["x", "y"], + "include_paths": ["a", "b"], + } + assert captured_kwargs["page_options"]["exclude_tags"] == ["nav", "footer"] + assert captured_kwargs["page_options"]["include_tags"] == ["main"] + assert captured_kwargs["page_options"]["only_main_content"] is False + assert captured_kwargs["page_options"]["wait_time"] == 1000 + + def test_get_crawl_status_active_and_completed(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "running", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 3}}, + "number_of_documents": 1, + "duration": "00:00:01.500000", + }, + ) + + active = provider.get_crawl_status("job-1") + assert active["status"] == "active" + assert active["data"] == [] + assert active["time_consuming"] == pytest.approx(1.5) + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "completed", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 2}}, + "number_of_documents": 2, + "duration": "00:00:02.000000", + }, + ) + monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}])) + + completed = provider.get_crawl_status("job-2") + assert completed["status"] == "completed" + assert completed["data"] == [{"url": "u"}] + + def test_get_crawl_url_data_and_scrape(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url}) + assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}])) + assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([])) + assert provider.get_crawl_url_data("job", "u1") is None + + def test_structure_data_validation_and_get_results_pagination(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + with pytest.raises(ValueError, match="Invalid result object"): + provider._structure_data({"result": "not-a-dict"}) + + structured = provider._structure_data( + { + "url": "https://example.com", + "result": { + "metadata": {"title": "Title", "description": "Desc"}, + "markdown": "Body", + }, + } + ) + assert structured["title"] == "Title" + assert structured["markdown"] == "Body" + + responses = [ + { + "results": [ + { + "url": "https://a", + "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"}, + } + ], + "next": "next-page", + }, + {"results": [], "next": None}, + ] + + monkeypatch.setattr( + provider.client, + "get_crawl_request_results", + lambda crawl_request_id, page, page_size, query_params: responses.pop(0), + ) + + results = list(provider._get_results("job-1")) + assert len(results) == 1 + assert results[0]["source_url"] == "https://a" + + def test_scrape_url_uses_client_and_structure(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + monkeypatch.setattr( + provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"} + ) + + result = provider.scrape_url("u") + + assert result["source_url"] == "u" + + +class TestWaterCrawlWebExtractor: + def test_extract_crawl_and_scrape_modes(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: { + "markdown": "crawl", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data", + lambda provider, url, tenant_id, only_main_content: { + "markdown": "scrape", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + + crawl_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + scrape_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert crawl_extractor.extract()[0].page_content == "crawl" + assert scrape_extractor.extract()[0].page_content == "scrape" + + def test_extract_crawl_returns_empty_when_service_returns_none(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: None, + ) + + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_unknown_mode_returns_empty(self): + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other") + + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/conftest.py b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py new file mode 100644 index 0000000000..2a3860e107 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py @@ -0,0 +1,33 @@ +from contextlib import AbstractContextManager, nullcontext +from typing import Any + +import pytest + + +class _FakeFlaskApp: + def app_context(self) -> AbstractContextManager[None]: + return nullcontext() + + +class _FakeExecutor: + def __init__(self, future: Any) -> None: + self._future = future + + def __enter__(self) -> "_FakeExecutor": + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + return False + + def submit(self, func: object, preview: object) -> Any: + return self._future + + +@pytest.fixture +def fake_flask_app() -> _FakeFlaskApp: + return _FakeFlaskApp() + + +@pytest.fixture +def fake_executor_cls() -> type[_FakeExecutor]: + return _FakeExecutor diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py new file mode 100644 index 0000000000..2451db70b6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -0,0 +1,629 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelFeature + + +class TestParagraphIndexProcessor: + @pytest.fixture + def processor(self) -> ParagraphIndexProcessor: + return ParagraphIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def _llm_result(self, content: str = "summary") -> LLMResult: + return LLMResult( + model="llm-model", + message=AssistantPromptMessage(content=content), + usage=LLMUsage.empty_usage(), + ) + + def test_extract_forwards_automatic_flag(self, processor: ParagraphIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected_docs + docs = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + rules_without_segmentation = SimpleNamespace(segmentation=None) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=rules_without_segmentation, + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".first", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + return_value=".first", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + ): + documents = processor.transform([source_document], process_rule=process_rule) + + assert len(documents) == 1 + assert documents[0].page_content == "first" + assert documents[0].attachments is not None + assert documents[0].metadata["doc_hash"] == "hash" + + def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content="text", metadata={})] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ) as mock_validate, + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_get_content_files", return_value=[]), + ): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"}) + + assert mock_validate.call_count == 1 + + def test_load_creates_vector_and_multimodal_when_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + docs = [Document(page_content="chunk", metadata={})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + vector = mock_vector_cls.return_value + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + mock_keyword_cls.assert_not_called() + + def test_load_uses_keyword_add_texts_with_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs, keywords_list=["k1", "k2"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"]) + + def test_load_uses_keyword_add_texts_without_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) + + def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_economy_deletes_summaries_and_keywords( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + mock_keyword_cls.return_value.delete.assert_called_once() + + def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.clean(dataset, ["node-2"], with_keywords=True) + + mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"]) + + def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9) + rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [accepted, rejected] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + def test_index_list_chunks_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"]) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_list_chunks_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.index(dataset, dataset_document, ["chunk-3"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once() + + def test_index_multimodal_structure_handles_files_and_account_lookup( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + chunk_with_files = SimpleNamespace( + content="content-1", + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + chunk_without_files = SimpleNamespace(content="content-2", files=None) + structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"), + ): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + assert mock_files.call_count == 1 + + def test_index_multimodal_structure_requires_valid_account( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None: + preview = processor.format_preview(["chunk-1", "chunk-2"]) + assert preview["chunk_structure"] == "text_model" + assert preview["total_segments"] == 2 + + with pytest.raises(ValueError, match="Chunks is not a list"): + processor.format_preview({"not": "a-list"}) + + def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())): + result = processor.generate_summary_preview( + "tenant-1", preview_items, {"enable": True}, doc_language="English" + ) + assert all(item.summary == "summary" for item in result) + + with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True}) + + def test_generate_summary_preview_fallback_without_flask_context(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())), + ): + result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_timeout( + self, processor: ParagraphIndexProcessor, fake_executor_cls: type + ) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + future.cancel.assert_called_once() + + def test_generate_summary_validates_input(self) -> None: + with pytest.raises(ValueError, match="must be enabled"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False}) + + with pytest.raises(ValueError, match="model_name and model_provider_name"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) + + def test_generate_summary_text_only_flow(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + model_instance.invoke_llm.return_value = self._llm_result("text summary") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota", + side_effect=RuntimeError("quota"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, usage = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + document_language="English", + ) + + assert summary == "text summary" + assert isinstance(usage, LLMUsage) + mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota") + + def test_generate_summary_handles_vision_and_image_conversion(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = self._llm_result("vision summary") + image_file = SimpleNamespace() + image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch.object( + ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file] + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + return_value=image_content, + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, _ = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + segment_id="seg-1", + ) + + assert summary == "vision summary" + mock_extract_text.assert_not_called() + + def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = object() + image_file = SimpleNamespace() + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT", + "Prompt {missing}", + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + side_effect=RuntimeError("bad image"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + with pytest.raises(ValueError, match="Expected LLMResult"): + ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + ) + + mock_logger.warning.assert_called_with( + "Failed to convert image file to prompt message content: %s", "bad image" + ) + + def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None: + text = ( + " " + " " + "" + ) + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + non_image_upload = SimpleNamespace( + id="22222222-2222-2222-2222-222222222222", + tenant_id="tenant-1", + name="file.txt", + mime_type="text/plain", + extension="txt", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload, non_image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + return_value=SimpleNamespace(id="file-1"), + ) as mock_builder, + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert len(files) == 1 + assert mock_builder.call_count == 1 + mock_logger.warning.assert_not_called() + + def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: + assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == [] + + def test_extract_images_from_text_logs_when_build_fails(self) -> None: + text = "" + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + side_effect=RuntimeError("build failed"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert files == [] + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments(self) -> None: + image_upload = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="", + size=1, + key="k1", + ) + bad_upload = SimpleNamespace( + id="file-2", + name="broken", + extension=None, + mime_type="image/png", + source_url="", + size=1, + key="k2", + ) + non_image_upload = SimpleNamespace( + id="file-3", + name="text", + extension="txt", + mime_type="text/plain", + source_url="", + size=1, + key="k3", + ) + execute_result = Mock() + execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)] + session = Mock() + session.execute.return_value = execute_result + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert len(files) == 1 + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments_empty(self) -> None: + execute_result = Mock() + execute_result.all.return_value = [] + session = Mock() + session.execute.return_value = execute_result + + with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session): + empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert empty_files == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py new file mode 100644 index 0000000000..abe40f05d1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -0,0 +1,523 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.models.document import AttachmentDocument, ChildDocument, Document +from services.entities.knowledge_entities.knowledge_entities import ParentMode + + +class TestParentChildIndexProcessor: + @pytest.fixture + def processor(self) -> ParentChildIndexProcessor: + return ParentChildIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + document.dataset_process_rule_id = None + return document + + def _segmentation(self) -> SimpleNamespace: + return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n") + + def _paragraph_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + segmentation=self._segmentation(), + subchunk_segmentation=self._segmentation(), + ) + + def _full_doc_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation() + ) + + def test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None: + extract_setting = Mock() + expected = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected + documents = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert documents == expected + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".parent", metadata={}), + Document(page_content=" ", metadata={}), + ] + parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + return_value=".parent", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + ): + result = processor.transform( + [parent_document], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=False, + ) + + assert len(result) == 1 + assert result[0].page_content == "parent" + assert result[0].children == child_docs + assert result[0].attachments is not None + + def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)] + documents = [ + Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}), + ] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch.object(processor, "_split_child_nodes", return_value=[]), + ): + result = processor.transform( + documents, + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 10 + + def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None: + docs = [ + Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + ] + child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._full_doc_rules(), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER", + 2, + ), + ): + result = processor.transform( + docs, + process_rule={"mode": "hierarchical", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 1 + assert len(result[0].children or []) == 2 + assert result[0].attachments is not None + + def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + parent_doc = Document( + page_content="parent", + metadata={}, + children=[ + ChildDocument(page_content="child-1", metadata={}), + ChildDocument(page_content="child-2", metadata={}), + ], + ) + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs) + + assert vector.create.call_count == 1 + formatted_docs = vector.create.call_args[0][0] + assert len(formatted_docs) == 2 + assert all(isinstance(doc, Document) for doc in formatted_docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + delete_query = Mock() + where_query = Mock() + where_query.delete.return_value = 2 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean( + dataset, + ["node-1"], + delete_child_chunks=True, + precomputed_child_node_ids=["child-1", "child-2"], + ) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_queries_child_ids_when_not_precomputed( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + child_query = Mock() + child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + session = Mock() + session.query.return_value = child_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_child_chunks=False) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + + def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + where_query = Mock() + where_query.delete.return_value = 3 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_child_chunks=True) + + vector.delete.assert_called_once() + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session", + return_value=session_ctx, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[]) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + + def test_clean_deletes_all_summaries_when_node_ids_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + + def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8) + low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [ok_result, low_result] + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "keep" + assert docs[0].metadata["score"] == 0.8 + + def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=None) + + with pytest.raises(ValueError, match="No subchunk segmentation found"): + processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None) + + def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=self._segmentation()) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".child-1", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + ): + child_docs = processor._split_child_nodes( + Document(page_content="parent", metadata={}), rules, "custom", None + ) + + assert len(child_docs) == 1 + assert child_docs[0].page_content == "child-1" + assert child_docs[0].metadata["doc_hash"] == "hash" + + def test_index_creates_process_rule_segments_and_vectors( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[ + SimpleNamespace( + parent_content="parent text", + child_contents=["child-1", "child-2"], + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + ], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + side_effect=lambda text: f"hash-{text}", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + assert dataset_document.dataset_process_rule_id == "rule-1" + session.add.assert_called_once_with(dataset_rule) + session.flush.assert_called_once() + session.commit.assert_called_once() + mock_store_cls.return_value.add_documents.assert_called_once() + assert mock_vector_cls.return_value.create.call_count == 1 + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_uses_content_files_when_files_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + mock_files.assert_called_once() + + def test_index_raises_when_account_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])], + ) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ): + preview = processor.format_preview({"parent_child_chunks": []}) + + assert preview["chunk_structure"] == "hierarchical_model" + assert preview["parent_mode"] == ParentMode.PARAGRAPH + assert preview["total_segments"] == 1 + + def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ): + result = processor.generate_summary_preview( + "tenant-1", preview_texts, {"enable": True}, doc_language="English" + ) + + assert all(item.summary == "summary" for item in result) + + def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + side_effect=RuntimeError("summary failed"), + ): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + def test_generate_summary_preview_falls_back_without_flask_context( + self, processor: ParentChildIndexProcessor + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ), + ): + result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_handles_timeout( + self, processor: ParentChildIndexProcessor, fake_executor_cls: type + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + future.cancel.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py new file mode 100644 index 0000000000..8596647ef3 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -0,0 +1,382 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ImmediateThread: + def __init__(self, target, args=(), kwargs=None): + self._target = target + self._args = args + self._kwargs = kwargs or {} + + def start(self) -> None: + self._target(*self._args, **self._kwargs) + + def join(self) -> None: + return None + + +class TestQAIndexProcessor: + @pytest.fixture + def processor(self) -> QAIndexProcessor: + return QAIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract: + mock_extract.return_value = expected_docs + + docs = processor.extract(extract_setting, process_rule_mode="automatic") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_preview_calls_formatter_once( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + split_node = Document(page_content=".question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform( + [document], + process_rule=process_rule, + preview=True, + tenant_id="tenant-1", + doc_language="English", + ) + + assert len(result) == 1 + assert result[0].metadata["answer"] == "A1" + mock_format.assert_called_once() + + def test_transform_non_preview_uses_thread_batches( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + documents = [ + Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), + Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}), + ] + split_node = Document(page_content="question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + patch( + "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread + ), + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1") + + assert len(result) == 2 + assert mock_format.call_count == 2 + + def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None: + not_csv_file = Mock(spec=FileStorage) + not_csv_file.filename = "qa.txt" + + with pytest.raises(ValueError, match="Only CSV files"): + processor.format_by_template(not_csv_file) + + def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]]) + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe): + docs = processor.format_by_template(csv_file) + + assert [doc.page_content for doc in docs] == ["Q1", "Q2"] + assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"] + + def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()): + with pytest.raises(ValueError, match="empty"): + processor.format_by_template(csv_file) + + def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch( + "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv") + ): + with pytest.raises(ValueError, match="bad csv"): + processor.format_by_template(csv_file) + + def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None: + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + processor.load(dataset, docs) + + mock_vector_cls.assert_not_called() + + def test_clean_handles_summary_deletion_and_vector_cleanup( + self, processor: QAIndexProcessor, dataset: Mock + ) -> None: + mock_segment = SimpleNamespace(id="seg-1") + mock_query = Mock() + mock_query.filter.return_value.all.return_value = [mock_segment] + mock_session = Mock() + mock_session.query.return_value = mock_query + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session", + return_value=session_context, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None: + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + vector.delete.assert_called_once() + + def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None: + result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9) + result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1) + + with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: + mock_retrieve.return_value = [result_ok, result_low] + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + + assert len(docs) == 1 + assert docs[0].page_content == "accepted" + assert docs[0].metadata["score"] == 0.9 + + def test_index_adds_documents_and_vectors_for_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + qa_chunks = SimpleNamespace( + qa_chunks=[ + SimpleNamespace(question="Q1", answer="A1"), + SimpleNamespace(question="Q2", answer="A2"), + ] + ) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + + def test_index_requires_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"), + ): + with pytest.raises(ValueError, match="must be high quality"): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None: + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ): + preview = processor.format_preview({"qa_chunks": []}) + + assert preview["chunk_structure"] == "qa_model" + assert preview["total_segments"] == 1 + assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}] + + def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None: + preview_items = [PreviewDetail(content="Q1")] + assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items + + def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + blank_document = Document(page_content=" ", metadata={}) + + processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English") + + assert all_qa_documents == [] + + def test_format_qa_document_builds_question_answer_documents( + self, processor: QAIndexProcessor, fake_flask_app + ) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.", + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert len(all_qa_documents) == 2 + assert all_qa_documents[0].page_content == "What is this?" + assert all_qa_documents[0].metadata["answer"] == "A test." + assert all_qa_documents[1].metadata["answer"] == "Coverage." + + def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + side_effect=RuntimeError("llm failure"), + ), + patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger, + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert all_qa_documents == [] + mock_logger.exception.assert_called_once_with("Failed to format qa document") + + def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None: + parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n") + + assert parsed == [{"question": "First?", "answer": "One."}, {"question": "Second?", "answer": "Two."}] diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py new file mode 100644 index 0000000000..b31bb6eea7 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -0,0 +1,291 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ForwardingBaseIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting, **kwargs): + return super().extract(extract_setting, **kwargs) + + def transform(self, documents, current_user=None, **kwargs): + return super().transform(documents, current_user=current_user, **kwargs) + + def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None): + return super().generate_summary_preview( + tenant_id=tenant_id, + preview_texts=preview_texts, + summary_index_setting=summary_index_setting, + doc_language=doc_language, + ) + + def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs): + return super().load( + dataset=dataset, + documents=documents, + multimodal_documents=multimodal_documents, + with_keywords=with_keywords, + **kwargs, + ) + + def clean(self, dataset, node_ids, with_keywords=True, **kwargs): + return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs) + + def index(self, dataset, document, chunks): + return super().index(dataset=dataset, document=document, chunks=chunks) + + def format_preview(self, chunks): + return super().format_preview(chunks) + + def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): + return super().retrieve( + retrieval_method=retrieval_method, + query=query, + dataset=dataset, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + + +class TestBaseIndexProcessor: + @pytest.fixture + def processor(self) -> _ForwardingBaseIndexProcessor: + return _ForwardingBaseIndexProcessor() + + def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None: + with pytest.raises(NotImplementedError): + processor.extract(Mock()) + with pytest.raises(NotImplementedError): + processor.transform([]) + with pytest.raises(NotImplementedError): + processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {}) + with pytest.raises(NotImplementedError): + processor.load(Mock(), []) + with pytest.raises(NotImplementedError): + processor.clean(Mock(), None) + with pytest.raises(NotImplementedError): + processor.index(Mock(), Mock(), {}) + with pytest.raises(NotImplementedError): + processor.format_preview([]) + with pytest.raises(NotImplementedError): + processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {}) + + def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: + with patch( + "core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000 + ): + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 49, 0, "", None) + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 1001, 0, "", None) + + def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + fixed_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder", + return_value=fixed_splitter, + ) as mock_fixed: + splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None) + + assert splitter is fixed_splitter + assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n" + assert mock_fixed.call_args.kwargs["chunk_size"] == 120 + + def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + auto_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder", + return_value=auto_splitter, + ) as mock_enhance: + splitter = processor._get_splitter("automatic", 0, 0, "", None) + + assert splitter is auto_splitter + assert "chunk_size" in mock_enhance.call_args.kwargs + + def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None: + markdown = "text  and " + images = processor._extract_markdown_images(markdown) + assert images == ["https://a/img.png", "/files/123/file-preview"] + + def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + assert processor._get_content_files(document) == [] + + def test_get_content_files_handles_all_sources_and_duplicates( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = [ + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview", + "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", + "https://example.com/remote.png?x=1", + ] + upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png") + upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") + upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") + upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") + db_query = Mock() + db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download, + patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download, + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document, current_user=Mock()) + + assert len(files) == 5 + assert all(isinstance(file, AttachmentDocument) for file in files) + assert files[0].metadata["doc_type"] == DocType.IMAGE + assert files[0].metadata["document_id"] == "doc-1" + assert files[0].metadata["dataset_id"] == "ds-1" + assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_tool_download.assert_called_once() + mock_image_download.assert_called_once() + + def test_get_content_files_skips_tool_and_remote_download_without_user( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"] + + with patch.object(processor, "_extract_markdown_images", return_value=images): + files = processor._get_content_files(document, current_user=None) + + assert files == [] + + def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] + db_query = Mock() + db_query.where.return_value.all.return_value = [] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document) + + assert files == [] + + def test_download_image_success_with_filename_from_content_disposition( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + response = Mock() + response.headers = { + "Content-Length": "4", + "content-disposition": "attachment; filename=test-image.png", + "content-type": "image/png", + } + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"data"] + upload_result = SimpleNamespace(id="upload-id") + + mock_db = Mock() + mock_db.engine = Mock() + + with ( + patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response), + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + upload_id = processor._download_image("https://example.com/test.png", current_user=Mock()) + + assert upload_id == "upload-id" + mock_file_service.return_value.upload_file.assert_called_once() + + def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None: + too_large = Mock() + too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"} + too_large.raise_for_status.return_value = None + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large): + assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None + + empty = Mock() + empty.headers = {"Content-Length": "0", "content-type": "image/png"} + empty.raise_for_status.return_value = None + empty.iter_bytes.return_value = [] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty): + assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None + + def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None: + response = Mock() + response.headers = {"content-type": "image/png"} + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response): + assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None + + def test_download_image_handles_timeout_request_and_unexpected_errors( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + request = httpx.Request("GET", "https://example.com/image.png") + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.TimeoutException("timeout"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.RequestError("bad request", request=request), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=RuntimeError("unexpected"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + db_query = Mock() + db_query.where.return_value.first.return_value = None + db_session = Mock() + db_session.query.return_value = db_query + + with patch("core.rag.index_processor.index_processor_base.db.session", db_session): + assert processor._download_tool_file("tool-id", current_user=Mock()) is None + + def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") + db_query = Mock() + db_query.where.return_value.first.return_value = tool_file + db_session = Mock() + db_session.query.return_value = db_query + mock_db = Mock() + mock_db.session = db_session + mock_db.engine = Mock() + upload_result = SimpleNamespace(id="upload-id") + + with ( + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load, + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + result = processor._download_tool_file("tool-id", current_user=Mock()) + + assert result == "upload-id" + mock_load.assert_called_once_with("k1") + mock_file_service.return_value.upload_file.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py new file mode 100644 index 0000000000..0fc666dbbf --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py @@ -0,0 +1,42 @@ +import pytest + +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class TestIndexProcessorFactory: + def test_requires_index_type(self) -> None: + factory = IndexProcessorFactory(index_type=None) + + with pytest.raises(ValueError, match="Index type must be specified"): + factory.init_index_processor() + + def test_builds_paragraph_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParagraphIndexProcessor) + + def test_builds_qa_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, QAIndexProcessor) + + def test_builds_parent_child_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParentChildIndexProcessor) + + def test_rejects_unsupported_index_type(self) -> None: + factory = IndexProcessorFactory(index_type="unsupported") + + with pytest.raises(ValueError, match="is not supported"): + factory.init_index_processor() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 0e53482c51..b150d677f1 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -12,13 +12,18 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e Tests follow the Arrange-Act-Assert pattern for clarity. """ +from operator import itemgetter +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest from core.model_manager import ModelInstance +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode @@ -26,7 +31,7 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -def create_mock_model_instance(): +def create_mock_model_instance() -> ModelInstance: """Create a properly configured mock ModelInstance for reranking tests.""" mock_instance = Mock(spec=ModelInstance) # Setup provider_model_bundle chain for check_model_support_vision @@ -59,14 +64,7 @@ class TestRerankModelRunner: @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" - mock_instance = Mock(spec=ModelInstance) - # Setup provider_model_bundle chain for check_model_support_vision - mock_instance.provider_model_bundle = Mock() - mock_instance.provider_model_bundle.configuration = Mock() - mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" - mock_instance.provider = "test-provider" - mock_instance.model_name = "test-model" - return mock_instance + return create_mock_model_instance() @pytest.fixture def rerank_runner(self, mock_model_instance): @@ -382,6 +380,206 @@ class TestRerankModelRunner: assert call_kwargs["user"] == "user123" +class _ForwardingBaseRerankRunner(BaseRerankRunner): + def run( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> list[Document]: + return super().run( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=top_n, + user=user, + query_type=query_type, + ) + + +class TestBaseRerankRunner: + def test_run_raises_not_implemented(self): + runner = _ForwardingBaseRerankRunner() + + with pytest.raises(NotImplementedError): + runner.run(query="python", documents=[]) + + +class TestRerankModelRunnerMultimodal: + @pytest.fixture + def mock_model_instance(self): + return create_mock_model_instance() + + @pytest.fixture + def rerank_runner(self, mock_model_instance): + return RerankModelRunner(rerank_model_instance=mock_model_instance) + + def test_run_returns_original_documents_for_non_text_query_without_vision_support( + self, rerank_runner, mock_model_instance + ): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), + ] + + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) + + assert result == documents + mock_model_instance.invoke_rerank.assert_not_called() + + def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"), + ] + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="doc", score=0.88)], + ) + + with ( + patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch.object( + rerank_runner, + "fetch_multimodal_rerank", + return_value=(rerank_result, documents), + ) as mock_multimodal, + ): + mock_mm.return_value.check_model_support_vision.return_value = True + result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY) + + mock_multimodal.assert_called_once() + assert len(result) == 1 + assert result[0].metadata["score"] == 0.88 + + def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE}, + provider="dify", + ) + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + external_doc = Document( + page_content="external-content", + metadata={}, + provider="external", + ) + query = Mock() + query.where.return_value.first.return_value = SimpleNamespace(key="image-key") + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc, text_doc, external_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc, text_doc, external_doc, external_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert len(unique_documents) == 3 + mock_load_once.assert_called_once_with("image-key") + text_rerank_call_args = mock_text_rerank.call_args.args + assert len(text_rerank_call_args[1]) == 3 + + def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, + provider="dify", + ) + query = Mock() + query.where.return_value.first.return_value = None + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [image_doc] + docs_arg = mock_text_rerank.call_args.args[1] + assert len(docs_arg) == 1 + + def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance): + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + query_chain = Mock() + query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="text-content", score=0.77)], + ) + mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="query-upload-id", + documents=[text_doc], + score_threshold=0.2, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [text_doc] + invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs + assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE + assert invoke_kwargs["docs"][0]["content"] == "text-content" + assert invoke_kwargs["user"] == "user-1" + + def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): + query_chain = Mock() + query_chain.where.return_value.first.return_value = None + + with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with pytest.raises(ValueError, match="Upload file not found for query"): + rerank_runner.fetch_multimodal_rerank( + query="missing-upload-id", + documents=[], + query_type=QueryType.IMAGE_QUERY, + ) + + def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner): + with pytest.raises(ValueError, match="is not supported"): + rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[], + query_type="unsupported_query_type", + ) + + class TestWeightRerankRunner: """Unit tests for WeightRerankRunner. @@ -512,34 +710,39 @@ class TestWeightRerankRunner: - TF-IDF scores are calculated correctly - Cosine similarity is computed for keyword vectors """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - - # Mock keyword extraction with specific keywords + keyword_map = { + "python programming": ["python", "programming"], + "Python is a programming language": ["python", "programming", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.side_effect = [ - ["python", "programming"], # query - ["python", "programming", "language"], # doc1 - ["javascript", "web"], # doc2 - ["java", "programming"], # doc3 - ] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="python programming", documents=sample_documents_with_vectors) - # Assert: Keywords are extracted and scores are calculated - assert len(result) == 3 - # Document 1 should have highest keyword score (matches both query terms) - # Document 3 should have medium score (matches one term) - # Document 2 should have lowest score (matches no terms) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_vector_score_calculation( self, @@ -556,30 +759,42 @@ class TestWeightRerankRunner: - Cosine similarity is calculated with document vectors - Vector scores are properly normalized """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test query": ["test"], + "Python is a programming language": ["python"], + "JavaScript for web development": ["javascript"], + "Java object-oriented programming": ["java"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding model mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance - # Mock cache embedding with specific query vector mock_cache_instance = MagicMock() query_vector = [0.2, 0.3, 0.4, 0.5] mock_cache_instance.embed_query.return_value = query_vector mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test query", documents=sample_documents_with_vectors) - # Assert: Vector scores are calculated - assert len(result) == 3 - # Verify cosine similarity was computed (doc2 vector is closest to query vector) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_score_threshold_filtering_weighted( self, @@ -742,28 +957,40 @@ class TestWeightRerankRunner: - Keyword weight (0.4) is applied to keyword scores - Combined score is the sum of weighted components """ - # Arrange: Create runner with known weights runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Python is a programming language": ["python", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test", documents=sample_documents_with_vectors) - # Assert: Scores are combined with weights - # Score = 0.6 * vector_score + 0.4 * keyword_score - assert len(result) == 3 - assert all("score" in doc.metadata for doc in result) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_existing_vector_score_in_metadata( self, @@ -778,7 +1005,6 @@ class TestWeightRerankRunner: - If document already has a score in metadata, it's used - Cosine similarity calculation is skipped for such documents """ - # Arrange: Documents with pre-existing scores documents = [ Document( page_content="Content with existing score", @@ -790,24 +1016,29 @@ class TestWeightRerankRunner: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Content with existing score": ["test"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", documents) + vector_scores = runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting) + expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0] + result = runner.run(query="test", documents=documents) - # Assert: Existing score is used in calculation assert len(result) == 1 - # The final score should incorporate the existing score (0.95) with vector weight (0.6) + assert result[0].metadata["doc_id"] == "doc1" + assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6) class TestRerankRunnerFactory: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index ca08cb0591..b90c4935af 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,80 +1,41 @@ -""" -Unit tests for dataset retrieval functionality. - -This module provides comprehensive test coverage for the RetrievalService class, -which is responsible for retrieving relevant documents from datasets using various -search strategies. - -Core Retrieval Mechanisms Tested: -================================== -1. **Vector Search (Semantic Search)** - - Uses embedding vectors to find semantically similar documents - - Supports score thresholds and top-k limiting - - Can filter by document IDs and metadata - -2. **Keyword Search** - - Traditional text-based search using keyword matching - - Handles special characters and query escaping - - Supports document filtering - -3. **Full-Text Search** - - BM25-based full-text search for text matching - - Used in hybrid search scenarios - -4. **Hybrid Search** - - Combines vector and full-text search results - - Implements deduplication to avoid duplicate chunks - - Uses DataPostProcessor for score merging with configurable weights - -5. **Score Merging Algorithms** - - Deduplication based on doc_id - - Retains higher-scoring duplicates - - Supports weighted score combination - -6. **Metadata Filtering** - - Filters documents based on metadata conditions - - Supports document ID filtering - -Test Architecture: -================== -- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app) -- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.) - rather than at the class level to properly simulate the ThreadPoolExecutor behavior -- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern -- **Isolation**: Each test is independent and doesn't rely on external state - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v - - # Run a specific test class - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\ -TestRetrievalService::test_vector_search_basic -v - -Notes: -====== -- The RetrievalService uses ThreadPoolExecutor for concurrent search operations -- Tests mock the individual search methods to avoid threading complexity -- All mocked search methods modify the all_documents list in-place -- Score thresholds and top-k limits are enforced by the search methods -""" - +import threading +from contextlib import contextmanager, nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest +from flask import Flask, current_app +from sqlalchemy import column +from core.app.app_config.entities import ( + Condition as AppCondition, +) +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, +) +from core.app.app_config.entities import ( + MetadataFilteringCondition as AppMetadataFilteringCondition, +) +from core.app.app_config.entities import ( + ModelConfig as AppModelConfig, +) +from core.app.app_config.entities import ModelConfig as WorkflowModelConfig +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.nodes.knowledge_retrieval import exc +from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== @@ -2013,3 +1974,3091 @@ class TestDocumentModel: assert doc1 == doc2 assert doc1 != doc3 + + +# ==================== Helper Functions ==================== + + +def create_mock_dataset_methods( + dataset_id: str | None = None, + tenant_id: str | None = None, + provider: str = "dify", + indexing_technique: str = "high_quality", + available_document_count: int = 10, +) -> Mock: + """ + Create a mock Dataset object for testing. + + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant ID for the dataset + provider: Provider type ("dify" or "external") + indexing_technique: Indexing technique ("high_quality" or "economy") + available_document_count: Number of available documents + + Returns: + Mock: A properly configured Dataset mock + """ + dataset = Mock(spec=Dataset) + dataset.id = dataset_id or str(uuid4()) + dataset.tenant_id = tenant_id or str(uuid4()) + dataset.name = "test_dataset" + dataset.provider = provider + dataset.indexing_technique = indexing_technique + dataset.available_document_count = available_document_count + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "top_k": 4, + "score_threshold_enabled": False, + } + return dataset + + +def create_mock_document_methods( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +# ==================== Test _check_knowledge_rate_limit ==================== + + +class TestCheckKnowledgeRateLimit: + """ + Test suite for _check_knowledge_rate_limit method. + + The _check_knowledge_rate_limit method validates whether a tenant has + exceeded their knowledge retrieval rate limit. This is important for: + - Preventing abuse of the knowledge retrieval system + - Enforcing subscription plan limits + - Tracking usage for billing purposes + + Test Cases: + ============ + 1. Rate limit disabled - no exception raised + 2. Rate limit enabled but not exceeded - no exception raised + 3. Rate limit enabled and exceeded - RateLimitExceededError raised + 4. Redis operations are performed correctly + 5. RateLimitLog is created when limit is exceeded + """ + + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service): + """ + Test that when rate limit is disabled, no exception is raised. + + This test verifies the behavior when the tenant's subscription + does not have rate limiting enabled. + + Verifies: + - FeatureService.get_knowledge_rate_limit is called + - No Redis operations are performed + - No exception is raised + - Retrieval proceeds normally + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit disabled + mock_limit = Mock() + mock_limit.enabled = False + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify FeatureService was called + mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id) + + # Verify no Redis operations were performed + assert not mock_redis.zadd.called + assert not mock_redis.zremrangebyscore.called + assert not mock_redis.zcard.called + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory): + """ + Test that when rate limit is enabled but not exceeded, no exception is raised. + + This test simulates a tenant making requests within their rate limit. + The Redis sorted set stores timestamps of recent requests, and old + requests (older than 60 seconds) are removed. + + Verifies: + - Redis zadd is called to track the request + - Redis zremrangebyscore removes old entries + - Redis zcard returns count within limit + - No exception is raised + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 # Current time in milliseconds + mock_time.time.return_value = current_time / 1000 # Return seconds + mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds + + # Mock Redis operations + # zcard returns 50 (within limit of 100) + mock_redis.zcard.return_value = 50 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify Redis operations + expected_key = f"rate_limit_{tenant_id}" + mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time}) + mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000) + mock_redis.zcard.assert_called_once_with(expected_key) + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_exceeded_raises_exception( + self, mock_time, mock_redis, mock_feature_service, mock_session_factory + ): + """ + Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised. + + This test simulates a tenant exceeding their rate limit. When the count + of recent requests exceeds the limit, an exception should be raised and + a RateLimitLog should be created. + + Verifies: + - Redis zcard returns count exceeding limit + - RateLimitExceededError is raised with correct message + - RateLimitLog is created in database + - Session operations are performed correctly + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 + mock_time.time.return_value = current_time / 1000 + + # Mock Redis operations - return count exceeding limit + mock_redis.zcard.return_value = 150 # Exceeds limit of 100 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert + with pytest.raises(exc.RateLimitExceededError) as exc_info: + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify exception message + assert "knowledge base request rate limit" in str(exc_info.value) + + # Verify RateLimitLog was created + mock_session.add.assert_called_once() + added_log = mock_session.add.call_args[0][0] + assert added_log.tenant_id == tenant_id + assert added_log.subscription_plan == "professional" + assert added_log.operation == "knowledge" + + +# ==================== Test _get_available_datasets ==================== + + +class TestGetAvailableDatasets: + """ + Test suite for _get_available_datasets method. + + The _get_available_datasets method retrieves datasets that are available + for retrieval. A dataset is considered available if: + - It belongs to the specified tenant + - It's in the list of requested dataset_ids + - It has at least one completed, enabled, non-archived document OR + - It's an external provider dataset + + Note: Due to SQLAlchemy subquery complexity, full testing is done in + integration tests. Unit tests here verify basic behavior. + """ + + def test_method_exists_and_has_correct_signature(self): + """ + Test that the method exists and has the correct signature. + + Verifies: + - Method exists on DatasetRetrieval class + - Accepts tenant_id and dataset_ids parameters + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + + # Assert - method exists + assert hasattr(dataset_retrieval, "_get_available_datasets") + # Assert - method is callable + assert callable(dataset_retrieval._get_available_datasets) + + +# ==================== Test knowledge_retrieval ==================== + + +class TestDatasetRetrievalKnowledgeRetrieval: + """ + Test suite for knowledge_retrieval method. + + The knowledge_retrieval method is the main entry point for retrieving + knowledge from datasets. It orchestrates the entire retrieval process: + 1. Checks rate limits + 2. Gets available datasets + 3. Applies metadata filtering if enabled + 4. Performs retrieval (single or multiple mode) + 5. Formats and returns results + + Test Cases: + ============ + 1. Single mode retrieval + 2. Multiple mode retrieval + 3. Metadata filtering disabled + 4. Metadata filtering automatic + 5. Metadata filtering manual + 6. External documents handling + 7. Dify documents handling + 8. Empty results handling + 9. Rate limit exceeded + 10. No available datasets + """ + + def test_knowledge_retrieval_single_mode_basic(self): + """ + Test knowledge_retrieval in single retrieval mode - basic check. + + Note: Full single mode testing requires complex model mocking and + is better suited for integration tests. This test verifies the + method accepts single mode requests. + + Verifies: + - Method can accept single mode request + - Request parameters are correctly structured + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + model_mode="chat", + completion_params={"temperature": 0.7}, + ) + + # Assert - request is properly structured + assert request.retrieval_mode == "single" + assert request.model_provider == "openai" + assert request.model_name == "gpt-4" + assert request.model_mode == "chat" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor): + """ + Test knowledge_retrieval in multiple retrieval mode. + + In multiple mode, retrieval is performed across all datasets and + results are combined and reranked. + + Verifies: + - Rate limit is checked + - Available datasets are retrieved + - Multiple retrieval is performed + - Results are combined and reranked + - Results are formatted correctly + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id1 = str(uuid4()) + dataset_id2 = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id1, dataset_id2], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + score_threshold=0.7, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets + mock_dataset1 = create_mock_dataset_methods(dataset_id=dataset_id1, tenant_id=tenant_id) + mock_dataset2 = create_mock_dataset_methods(dataset_id=dataset_id2, tenant_id=tenant_id) + with patch.object( + dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2] + ): + # Mock get_metadata_filter_condition + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return documents + doc1 = create_mock_document_methods("Python is great", "doc1", score=0.9) + doc2 = create_mock_document_methods("Python is awesome", "doc2", score=0.8) + with patch.object( + dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2] + ) as mock_multiple_retrieve: + # Mock format_retrieval_documents + mock_record = Mock() + mock_record.segment = Mock() + mock_record.segment.dataset_id = dataset_id1 + mock_record.segment.document_id = str(uuid4()) + mock_record.segment.index_node_hash = "hash123" + mock_record.segment.hit_count = 5 + mock_record.segment.word_count = 100 + mock_record.segment.position = 1 + mock_record.segment.get_sign_content.return_value = "Python is great" + mock_record.segment.answer = None + mock_record.score = 0.9 + mock_record.child_chunks = [] + mock_record.summary = None + mock_record.files = None + + mock_retrieval_service = Mock() + mock_retrieval_service.format_retrieval_documents.return_value = [mock_record] + + with patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService", + return_value=mock_retrieval_service, + ): + # Mock database queries + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + mock_dataset_from_db = Mock() + mock_dataset_from_db.id = dataset_id1 + mock_dataset_from_db.name = "test_dataset" + + mock_document = Mock() + mock_document.id = str(uuid4()) + mock_document.name = "test_doc" + mock_document.data_source_type = "upload_file" + mock_document.doc_metadata = {} + + mock_session.query.return_value.filter.return_value.all.return_value = [ + mock_dataset_from_db + ] + mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( + [mock_dataset_from_db, mock_document] + ) + + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + mock_multiple_retrieve.assert_called_once() + + def test_knowledge_retrieval_metadata_filtering_disabled(self): + """ + Test knowledge_retrieval with metadata filtering disabled. + + When metadata filtering is disabled, get_metadata_filter_condition is + NOT called (the method checks metadata_filtering_mode != "disabled"). + + Verifies: + - get_metadata_filter_condition is NOT called when mode is "disabled" + - Retrieval proceeds without metadata filters + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + metadata_filtering_mode="disabled", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + # Mock get_metadata_filter_condition - should NOT be called when disabled + with patch.object( + dataset_retrieval, + "get_metadata_filter_condition", + return_value=(None, None), + ) as mock_get_metadata: + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + # get_metadata_filter_condition should NOT be called when mode is "disabled" + mock_get_metadata.assert_not_called() + + def test_knowledge_retrieval_with_external_documents(self): + """ + Test knowledge_retrieval with external documents. + + External documents come from external knowledge bases and should + be formatted differently than Dify documents. + + Verifies: + - External documents are handled correctly + - Provider is set to "external" + - Metadata includes external-specific fields + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id, provider="external") + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Create external document + external_doc = create_mock_document_methods( + "External knowledge", + "doc1", + score=0.9, + provider="external", + additional_metadata={ + "dataset_id": dataset_id, + "dataset_name": "external_kb", + "document_id": "ext_doc1", + "title": "External Document", + }, + ) + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + if result: + assert result[0].metadata.data_source_type == "external" + + def test_knowledge_retrieval_empty_results(self): + """ + Test knowledge_retrieval when no documents are found. + + Verifies: + - Empty list is returned + - No errors are raised + - All dependencies are still called + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return empty list + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded(self): + """ + Test knowledge_retrieval when rate limit is exceeded. + + Verifies: + - RateLimitExceededError is raised + - No further processing occurs + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=exc.RateLimitExceededError("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(exc.RateLimitExceededError): + dataset_retrieval.knowledge_retrieval(request) + + def test_knowledge_retrieval_no_available_datasets(self): + """ + Test knowledge_retrieval when no datasets are available. + + Verifies: + - Empty list is returned + - No retrieval is attempted + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets to return empty list + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self): + """ + Test that knowledge_retrieval processes multiple documents with different scores. + + Note: Full sorting and position testing requires complex SQLAlchemy mocking + which is better suited for integration tests. This test verifies documents + with different scores can be created and have their metadata. + + Verifies: + - Documents can be created with different scores + - Score metadata is properly set + """ + # Create documents with different scores + doc1 = create_mock_document_methods("Low score", "doc1", score=0.6) + doc2 = create_mock_document_methods("High score", "doc2", score=0.95) + doc3 = create_mock_document_methods("Medium score", "doc3", score=0.8) + + # Assert - each document has the correct score + assert doc1.metadata["score"] == 0.6 + assert doc2.metadata["score"] == 0.95 + assert doc3.metadata["score"] == 0.8 + + # Assert - documents are correctly sorted (not the retrieval result, just the list) + unsorted = [doc1, doc2, doc3] + sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True) + assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6] + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. Using DatasetDocument.doc_metadata JSON field operations + 3. Adding appropriate SQLAlchemy expressions to the filters list + 4. Returning the updated filters list + + Mocking Strategy: + ================== + - Mock DatasetDocument.doc_metadata to avoid database dependencies + - Verify filter expressions are created correctly + - Test with various data types (str, int, float, list) + """ + + @pytest.fixture + def retrieval(self): + """ + Create a DatasetRetrieval instance for testing. + + Returns: + DatasetRetrieval: Instance to test process_metadata_filter_func + """ + return DatasetRetrieval() + + @pytest.fixture + def mock_doc_metadata(self): + """ + Mock the DatasetDocument.doc_metadata JSON field. + + The method uses DatasetDocument.doc_metadata[metadata_name] to access + JSON fields. We mock this to avoid database dependencies. + + Returns: + Mock: Mocked doc_metadata attribute + """ + mock_metadata_field = MagicMock() + + # Create mock for string access + mock_string_access = MagicMock() + mock_string_access.like = MagicMock() + mock_string_access.notlike = MagicMock() + mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_string_access.in_ = MagicMock(return_value=MagicMock()) + + # Create mock for float access (for numeric comparisons) + mock_float_access = MagicMock() + mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__le__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) + + # Create mock for null checks + mock_null_access = MagicMock() + mock_null_access.is_ = MagicMock(return_value=MagicMock()) + mock_null_access.isnot = MagicMock(return_value=MagicMock()) + + # Setup __getitem__ to return appropriate mock based on usage + def getitem_side_effect(name): + if name in ["author", "title", "category"]: + return mock_string_access + elif name in ["year", "price", "rating"]: + return mock_float_access + else: + return mock_string_access + + mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) + mock_metadata_field.as_string.return_value = mock_string_access + mock_metadata_field.as_float.return_value = mock_float_access + mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ + mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot + + return mock_metadata_field + + # ==================== String Condition Tests ==================== + + def test_contains_condition_string_value(self, retrieval): + """ + Test 'contains' condition with string value. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value% syntax + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "John" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_contains_condition(self, retrieval): + """ + Test 'not contains' condition. + + Verifies: + - Filters list is populated with NOT LIKE expression + - Pattern matching uses %value% syntax with negation + """ + filters = [] + sequence = 0 + condition = "not contains" + metadata_name = "title" + value = "banned" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_start_with_condition(self, retrieval): + """ + Test 'start with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses value% syntax + """ + filters = [] + sequence = 0 + condition = "start with" + metadata_name = "category" + value = "tech" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_end_with_condition(self, retrieval): + """ + Test 'end with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. + + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. + + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. + + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. + + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. + + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. + + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. + + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. + + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. + + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. + + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. + + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. + + Verifies: + - Special characters are handled in value + - LIKE expression is created correctly + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "title" + value = "C++ & Python's features" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_zero_value_with_numeric_condition(self, retrieval): + """ + Test zero value with numeric comparison condition. + + Verifies: + - Zero is treated as valid value + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "price" + value = 0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_negative_value_with_numeric_condition(self, retrieval): + """ + Test negative value with numeric comparison condition. + + Verifies: + - Negative numbers are handled correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "temperature" + value = -10.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_float_value_with_integer_comparison(self, retrieval): + """ + Test float value with numeric comparison condition. + + Verifies: + - Float values work correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + +class TestKnowledgeRetrievalRegression: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name = "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). + """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + + def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # ✅ key + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." + + # Current buggy code should record an exception (not raise it) + assert not thread_exceptions, thread_exceptions + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class TestDatasetRetrievalAdditionalHelpers: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_llm_usage_and_record_usage(self, retrieval: DatasetRetrieval) -> None: + empty_usage = retrieval.llm_usage + assert empty_usage.total_tokens == 0 + + retrieval._record_usage(None) + assert retrieval.llm_usage.total_tokens == 0 + + usage_1 = LLMUsage.from_metadata({"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5}) + usage_2 = LLMUsage.from_metadata({"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}) + retrieval._record_usage(usage_1) + retrieval._record_usage(usage_2) + assert retrieval.llm_usage.total_tokens == 10 + + def test_replace_metadata_filter_value(self, retrieval: DatasetRetrieval) -> None: + assert retrieval._replace_metadata_filter_value("plain", {}) == "plain" + replaced = retrieval._replace_metadata_filter_value( + "hello {{name}}\n\t{{missing}}", + {"name": "world"}, + ) + assert replaced == "hello world {{missing}}" + + def test_process_metadata_filter_in_with_scalar_fallback(self) -> None: + filters: list = [] + result = DatasetRetrieval.process_metadata_filter_func( + sequence=0, + condition="in", + metadata_name="category", + value=123, + filters=filters, + ) + assert result is filters + assert len(filters) == 1 + + def test_calculate_vector_score(self, retrieval: DatasetRetrieval) -> None: + doc_high = Document(page_content="a", metadata={"score": 0.9}, provider="dify") + doc_low = Document(page_content="b", metadata={"score": 0.2}, provider="dify") + doc_no_meta = Document(page_content="c", metadata={}, provider="dify") + + filtered = retrieval.calculate_vector_score([doc_low, doc_high, doc_no_meta], top_k=1, score_threshold=0.5) + assert len(filtered) == 1 + assert filtered[0].metadata["score"] == 0.9 + + assert retrieval.calculate_vector_score([doc_low], top_k=2, score_threshold=1.0) == [] + + def test_calculate_keyword_score(self, retrieval: DatasetRetrieval) -> None: + documents = [ + Document(page_content="python language", metadata={"doc_id": "1"}, provider="dify"), + Document(page_content="java language", metadata={"doc_id": "2"}, provider="dify"), + ] + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [ + ["python", "language"], + ["python", "language"], + ["java", "language"], + ] + + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("python language", documents, top_k=1) + + assert len(ranked) == 1 + assert "keywords" in ranked[0].metadata + assert ranked[0].metadata["doc_id"] == "1" + + def test_send_trace_task(self, retrieval: DatasetRetrieval) -> None: + trace_manager = Mock() + retrieval.application_generate_entity = SimpleNamespace(trace_manager=trace_manager) + docs = [Document(page_content="d", metadata={}, provider="dify")] + + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_called_once() + + retrieval.application_generate_entity = None + trace_manager.reset_mock() + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_not_called() + + def test_on_query(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query=None, + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_not_called() + + retrieval._on_query( + query="python", + attachment_ids=["f1"], + dataset_ids=["d1", "d2"], + app_id="a1", + user_from="web", + user_id="u1", + ) + mock_session.add_all.assert_called() + mock_session.commit.assert_called() + + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: + usage = LLMUsage.empty_usage() + chunk_1 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace(message=SimpleNamespace(content="hello "), usage=usage), + ) + chunk_2 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace( + message=SimpleNamespace(content=[SimpleNamespace(data="world")]), + usage=None, + ), + ) + text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2])) + assert text == "hello world" + assert returned_usage == usage + + text_empty, usage_empty = retrieval._handle_invoke_result(iter([])) + assert text_empty == "" + assert usage_empty == LLMUsage.empty_usage() + + def test_get_prompt_template(self, retrieval: DatasetRetrieval) -> None: + model_config_chat = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=["x"], + ) + model_config_completion = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="completion", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + + with patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform") as mock_prompt_transform: + mock_prompt_transform.return_value.get_prompt.return_value = ["prompt"] + prompt_messages, stop = retrieval._get_prompt_template( + model_config=model_config_chat, + mode="chat", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages == ["prompt"] + assert stop == ["x"] + + with patch( + "core.rag.retrieval.dataset_retrieval.METADATA_FILTER_COMPLETION_PROMPT", + "{input_text} {metadata_fields}", + ): + prompt_messages_completion, stop_completion = retrieval._get_prompt_template( + model_config=model_config_completion, + mode="completion", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages_completion == ["prompt"] + assert stop_completion == [] + + with pytest.raises(ValueError): + retrieval._get_prompt_template( + model_config=model_config_chat, + mode="unknown-mode", + metadata_fields=[], + query="python", + ) + + def test_fetch_model_config_validation_and_success(self, retrieval: DatasetRetrieval) -> None: + with pytest.raises(ValueError, match="single_retrieval_config is required"): + retrieval._fetch_model_config("tenant-1", None) # type: ignore[arg-type] + + model_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat", completion_params={"stop": ["END"]}) + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance = Mock() + model_instance.model_type_instance.get_model_schema.return_value = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance + mock_cfg_entity.return_value = SimpleNamespace( + provider="openai", + model="gpt", + stop=["END"], + parameters={"temperature": 0.1}, + ) + + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model = SimpleNamespace(status=ModelStatus.NO_CONFIGURE) + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = provider_model + with pytest.raises(ValueError, match="credentials is not initialized"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.NO_PERMISSION + with pytest.raises(ValueError, match="currently not support"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.QUOTA_EXCEEDED + with pytest.raises(ValueError, match="quota exceeded"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.ACTIVE + bad_mode_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat") + bad_mode_cfg.mode = None # type: ignore[assignment] + with pytest.raises(ValueError, match="LLM mode is required"): + retrieval._fetch_model_config("tenant-1", bad_mode_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = Mock() + model_cfg_success = AppModelConfig( + provider="openai", + name="gpt", + mode="chat", + completion_params={"temperature": 0.1, "stop": ["END"]}, + ) + _, config = retrieval._fetch_model_config("tenant-1", model_cfg_success) + assert config.provider == "openai" + assert config.model == "gpt" + assert config.stop == ["END"] + assert "stop" not in config.parameters + + def test_automatic_metadata_filter_func(self, retrieval: DatasetRetrieval) -> None: + metadata_field = SimpleNamespace(name="author") + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([Mock()]) + model_config = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + session_scalars = Mock() + session_scalars.all.return_value = [metadata_field] + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, model_config)), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_handle_invoke_result", return_value=('{"metadata_map":[]}', usage)), + patch("core.rag.retrieval.dataset_retrieval.parse_and_check_json_markdown") as mock_parse, + patch.object(retrieval, "_record_usage") as mock_record_usage, + ): + mock_parse.return_value = { + "metadata_map": [ + { + "metadata_field_name": "author", + "metadata_field_value": "Alice", + "comparison_operator": "contains", + }, + { + "metadata_field_name": "ignored", + "metadata_field_value": "value", + "comparison_operator": "contains", + }, + ] + } + result = retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + assert result == [{"metadata_name": "author", "value": "Alice", "condition": "contains"}] + mock_record_usage.assert_called_once_with(usage) + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", side_effect=RuntimeError("boom")), + ): + with pytest.raises(RuntimeError, match="boom"): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: + db_query = Mock() + db_query.where.return_value = db_query + db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="disabled", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + assert mapping is None + assert condition is None + + automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), + ): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="automatic", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=AppMetadataFilteringCondition(logical_operator="or", conditions=[]), + inputs={}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.logical_operator == "or" + + manual_conditions = AppMetadataFilteringCondition( + logical_operator="and", + conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="manual", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=manual_conditions, + inputs={"name": "Alice"}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.conditions[0].value == "Alice" + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with pytest.raises(ValueError, match="Invalid metadata filtering mode"): + retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="unsupported", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + + def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: + session = Mock() + subquery_query = Mock() + subquery_query.where.return_value = subquery_query + subquery_query.group_by.return_value = subquery_query + subquery_query.having.return_value = subquery_query + subquery_query.subquery.return_value = SimpleNamespace( + c=SimpleNamespace( + dataset_id=column("dataset_id"), available_document_count=column("available_document_count") + ) + ) + + dataset_query = Mock() + dataset_query.outerjoin.return_value = dataset_query + dataset_query.where.return_value = dataset_query + dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.query.side_effect = [subquery_query, dataset_query] + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx): + available = retrieval._get_available_datasets("tenant-1", ["d1", "d2"]) + + assert [dataset.id for dataset in available] == ["d1", "d2"] + + def test_check_knowledge_rate_limit(self, retrieval: DatasetRetrieval) -> None: + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=2, subscription_plan="pro") + mock_redis.zcard.return_value = 1 + retrieval._check_knowledge_rate_limit("tenant-1") + mock_redis.zadd.assert_called_once() + + session = Mock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=1, subscription_plan="pro") + mock_redis.zcard.return_value = 2 + with pytest.raises(exc.RateLimitExceededError): + retrieval._check_knowledge_rate_limit("tenant-1") + session.add.assert_called_once() + + with patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit: + mock_limit.return_value = SimpleNamespace(enabled=False) + retrieval._check_knowledge_rate_limit("tenant-1") + + +def _doc( + provider: str = "dify", + content: str = "content", + score: float = 0.9, + dataset_id: str = "dataset-1", + document_id: str = "document-1", + doc_id: str = "node-1", + extra: dict | None = None, +) -> Document: + metadata = { + "score": score, + "dataset_id": dataset_id, + "document_id": document_id, + "doc_id": doc_id, + } + if extra: + metadata.update(extra) + return Document(page_content=content, metadata=metadata, provider=provider) + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class _JoinDrivenThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._started = False + self._alive = False + + def start(self) -> None: + self._started = True + self._alive = True + + def join(self, timeout=None) -> None: + if self._started and self._alive and self._target: + self._target(**self._kwargs) + self._alive = False + + def is_alive(self) -> bool: + return self._alive + + +@contextmanager +def _timer(): + yield {"cost": 1} + + +class TestKnowledgeRetrievalCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_returns_empty_when_query_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query=None, + retrieval_mode="multiple", + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + assert retrieval.knowledge_retrieval(request) == [] + + def test_raises_when_metadata_model_config_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query="query", + retrieval_mode="multiple", + metadata_filtering_mode="automatic", + metadata_model_config=None, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + with pytest.raises(ValueError, match="metadata_model_config is required"): + retrieval.knowledge_retrieval(request) + + @pytest.mark.parametrize( + ("status", "error_cls"), + [ + (ModelStatus.NO_CONFIGURE, "ModelCredentialsNotInitializedError"), + (ModelStatus.NO_PERMISSION, "ModelNotSupportedError"), + (ModelStatus.QUOTA_EXCEEDED, "ModelQuotaExceededError"), + ], + ) + def test_single_mode_raises_for_model_status( + self, + retrieval: DatasetRetrieval, + status: ModelStatus, + error_cls: str, + ) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["dataset-1"], + query="python", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + ) + provider_model_bundle = Mock() + provider_model_bundle.configuration.get_provider_model.return_value = SimpleNamespace(status=status) + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = Mock() + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + credentials={}, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + ): + mock_model_manager.return_value.get_model_instance.return_value = model_instance + with pytest.raises(Exception) as exc_info: + retrieval.knowledge_retrieval(request) + assert error_cls in type(exc_info.value).__name__ + + +class TestRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def _build_model_config(self, features: list[ModelFeature] | None = None): + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = SimpleNamespace(features=features or []) + provider_bundle = SimpleNamespace(model_type_instance=model_type_instance) + return ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt-4", + model_schema=Mock(), + mode="chat", + provider_model_bundle=provider_bundle, + credentials={}, + parameters={}, + stop=[], + ) + + def test_returns_none_when_dataset_ids_empty(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=[], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=self._build_model_config(), + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_returns_none_when_model_schema_missing(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + model_config = self._build_model_config() + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = Mock() + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config() + external_doc = _doc( + provider="external", + content="external content", + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External", "dataset_name": "External DS"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[external_doc]), + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert context == "external content" + assert files == [] + + def test_multiple_strategy_with_vision_and_source_details(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=4, + score_threshold=0.1, + rerank_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + reranking_enabled=True, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config(features=[ModelFeature.TOOL_CALL]) + external_doc = _doc( + provider="external", + content="external body", + score=0.8, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External Title", "dataset_name": "External DS"}, + ) + dify_doc = _doc( + provider="dify", + content="dify body", + score=0.9, + dataset_id="d1", + document_id="doc-1", + doc_id="node-1", + ) + record = SimpleNamespace( + segment=SimpleNamespace( + id="segment-1", + dataset_id="d1", + document_id="doc-1", + tenant_id="tenant-1", + hit_count=3, + word_count=11, + position=1, + index_node_hash="hash-1", + content="segment content", + answer="segment answer", + get_sign_content=lambda: "segment content", + ), + score=0.9, + summary="short summary", + files=None, + ) + dataset_item = SimpleNamespace(id="d1", name="Dataset One") + document_item = SimpleNamespace( + id="doc-1", + name="Document One", + data_source_type="upload_file", + doc_metadata={"lang": "en"}, + ) + upload_file = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="https://example.com/img.png", + size=123, + key="k1", + ) + execute_attachments = SimpleNamespace(all=lambda: [(SimpleNamespace(), upload_file)]) + execute_docs = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [document_item])) + execute_datasets = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [dataset_item])) + hit_callback = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", + return_value=[record], + ), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.DEBUGGER, + show_retrieve_source=True, + hit_callback=hit_callback, + message_id="m1", + vision_enabled=True, + ) + + assert "short summary" in (context or "") + assert "question:segment content answer:segment answer" in (context or "") + assert len(files or []) == 1 + hit_callback.return_retriever_resource_info.assert_called_once() + + +class TestSingleAndMultipleRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="External DS", + description=None, + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end") as mock_end, + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + mock_external.return_value = [ + {"content": "ext result", "metadata": {"k": "v"}, "score": 0.9, "title": "Ext Doc"} + ] + result = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + message_id="m1", + ) + + assert len(result) == 1 + assert result[0].provider == "external" + mock_end.assert_called_once() + assert retrieval.llm_usage.total_tokens == 2 + + def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="dataset desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + "top_k": 3, + "score_threshold_enabled": True, + "score_threshold": 0.2, + }, + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}) + result_doc = _doc(provider="dify", score=0.7, dataset_id="ds-1", document_id="doc-1", doc_id="node-1") + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.FunctionCallMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[result_doc] + ) as mock_retrieve, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end"), + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.ROUTER, + metadata_filter_document_ids={"ds-1": ["doc-1"]}, + metadata_condition=SimpleNamespace(), + ) + + assert results == [result_doc] + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc-1"] + assert retrieval.llm_usage.total_tokens == 1 + + def test_single_retrieve_returns_empty_when_no_dataset_selected(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls: + mock_router_cls.return_value.invoke.return_value = (None, LLMUsage.empty_usage()) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[ + SimpleNamespace(id="ds-1", name="DS", description=None), + ], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + ) + assert results == [] + + def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={"top_k": 2, "search_method": "semantic_search", "reranking_enable": False}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", LLMUsage.empty_usage()) + no_filter = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids=None, + metadata_condition=SimpleNamespace(), + ) + missing_doc_ids = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids={"other-ds": ["x"]}, + metadata_condition=None, + ) + assert no_filter == [] + assert missing_doc_ids == [] + + def test_multiple_retrieve_validation_paths(self, retrieval: DatasetRetrieval) -> None: + assert ( + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=[], + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + == [] + ) + + mixed = [ + SimpleNamespace(id="d1", indexing_technique="high_quality"), + SimpleNamespace(id="d2", indexing_technique="economy"), + ] + with pytest.raises(ValueError, match="different indexing technique"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=mixed, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="weighted_score", + reranking_enable=False, + ) + + high_quality_mismatch = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-b", + embedding_model_provider="provider-b", + ), + ] + with pytest.raises(ValueError, match="different embedding model"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=high_quality_mismatch, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + ) + + def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + ] + doc_a = _doc(provider="dify", score=0.8, dataset_id="d1", document_id="doc-1", doc_id="dup") + doc_b = _doc(provider="dify", score=0.7, dataset_id="d2", document_id="doc-2", doc_id="dup") + doc_external = _doc( + provider="external", + score=0.9, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"dataset_name": "Ext", "title": "Ext"}, + ) + app = Flask(__name__) + weights = {"vector_setting": {}} + + def fake_multiple_thread(**kwargs): + if kwargs["query"]: + kwargs["all_documents"].extend([doc_a, doc_b]) + if kwargs["attachment_id"]: + kwargs["all_documents"].append(doc_external) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=fake_multiple_thread), + patch.object(retrieval, "_on_query") as mock_on_query, + patch.object(retrieval, "_on_retrieval_end") as mock_end, + ): + result = retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + weights=weights, + attachment_ids=["att-1"], + message_id="m1", + ) + + assert len(result) == 2 + assert any(doc.provider == "external" for doc in result) + assert weights["vector_setting"]["embedding_provider_name"] == "provider-a" + assert weights["vector_setting"]["embedding_model_name"] == "model-a" + mock_on_query.assert_called_once() + mock_end.assert_called_once() + + def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ) + ] + app = Flask(__name__) + + def failing_thread(**kwargs): + kwargs["thread_exceptions"].append(RuntimeError("thread boom")) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=failing_thread), + ): + with pytest.raises(RuntimeError, match="thread boom"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + + +class TestInternalHooksCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_on_retrieval_end_without_dify_documents(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + with patch.object(retrieval, "_send_trace_task") as mock_trace: + retrieval._on_retrieval_end( + flask_app=app, + documents=[_doc(provider="external")], + message_id="m1", + timer={"cost": 1}, + ) + mock_trace.assert_called_once() + + def test_on_retrieval_end_dify_without_document_ids(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + doc = Document(page_content="x", metadata={"doc_id": "n1"}, provider="dify") + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=[doc], message_id="m1", timer={"cost": 1}) + mock_trace.assert_called_once() + + def test_on_retrieval_end_updates_segments_for_text_and_image(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + docs = [ + _doc(provider="dify", document_id="doc-a", doc_id="idx-a", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-b", doc_id="att-b", extra={"doc_type": DocType.IMAGE}), + _doc(provider="dify", document_id="doc-c", doc_id="idx-c", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-d", doc_id="att-d", extra={"doc_type": DocType.IMAGE}), + ] + dataset_docs = [ + SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-c", doc_form="qa_model"), + SimpleNamespace(id="doc-d", doc_form="qa_model"), + ] + child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] + segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] + bindings = [SimpleNamespace(segment_id="seg-b"), SimpleNamespace(segment_id="seg-d")] + + def _scalars(items): + result = Mock() + result.all.return_value = items + return result + + session = Mock() + session.scalars.side_effect = [ + _scalars(dataset_docs), + _scalars(child_chunks), + _scalars(segments), + _scalars(bindings), + ] + query = Mock() + query.where.return_value = query + session.query.return_value = query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) + + query.update.assert_called_once() + session.commit.assert_called_once() + mock_trace.assert_called_once() + + def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: + flask_app = SimpleNamespace(app_context=lambda: nullcontext()) + all_documents: list[Document] = [] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=None): + assert ( + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="d1", + query="python", + top_k=1, + all_documents=all_documents, + ) + == [] + ) + + external_dataset = SimpleNamespace( + id="ext-ds", + name="External", + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=external_dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + ): + mock_external.return_value = [{"content": "e", "metadata": {}, "score": 0.8, "title": "Ext"}] + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="ext-ds", + query="python", + top_k=1, + all_documents=all_documents, + ) + + economy_dataset = SimpleNamespace( + id="eco-ds", + provider="dify", + retrieval_model={"top_k": 1}, + indexing_technique="economy", + ) + high_dataset = SimpleNamespace( + id="hq-ds", + provider="dify", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 4, + "score_threshold": 0.3, + "score_threshold_enabled": True, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "x", "reranking_model_name": "y"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + }, + indexing_technique="high_quality", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", side_effect=[economy_dataset, high_dataset] + ), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[_doc(provider="dify")] + ) as mock_retrieve, + ): + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="eco-ds", + query="python", + top_k=2, + all_documents=all_documents, + ) + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="hq-ds", + query="python", + top_k=2, + all_documents=all_documents, + attachment_ids=["att-1"], + ) + assert mock_retrieve.call_count == 2 + assert len(all_documents) >= 3 + + def test_to_dataset_retriever_tool_paths(self, retrieval: DatasetRetrieval) -> None: + dataset_skip_zero = SimpleNamespace(id="d1", provider="dify", available_document_count=0) + dataset_ok_single = SimpleNamespace( + id="d2", + provider="dify", + available_document_count=2, + retrieval_model={"top_k": 2, "score_threshold_enabled": True, "score_threshold": 0.1}, + ) + single_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", + side_effect=[None, dataset_skip_zero, dataset_ok_single], + ), + patch( + "core.tools.utils.dataset_retriever.dataset_retriever_tool.DatasetRetrieverTool.from_dataset", + return_value="single-tool", + ) as mock_single_tool, + ): + single_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["missing", "d1", "d2"], + retrieve_config=single_config, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={"k": "v"}, + ) + + assert single_tools == ["single-tool"] + mock_single_tool.assert_called_once() + + multiple_config_missing = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + reranking_model=None, + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single): + with pytest.raises(ValueError, match="Reranking model is required"): + retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config_missing, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + + multiple_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + top_k=3, + score_threshold=0.2, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single), + patch( + "core.tools.utils.dataset_retriever.dataset_multi_retriever_tool.DatasetMultiRetrieverTool.from_dataset", + return_value="multi-tool", + ) as mock_multi_tool, + ): + multi_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + assert multi_tools == ["multi-tool"] + mock_multi_tool.assert_called_once() + + def test_additional_small_branches(self, retrieval: DatasetRetrieval) -> None: + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [[], []] + doc = Document(page_content="doc", metadata={"doc_id": "1"}, provider="dify") + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("query", [doc], top_k=1) + assert len(ranked) == 1 + assert ranked[0].metadata.get("score") == 0.0 + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + with pytest.raises(ValueError): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=None, # type: ignore[arg-type] + ) + + session_scalars = Mock() + session_scalars.all.return_value = [SimpleNamespace(name="author")] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(Mock(), Mock())), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_record_usage"), + ): + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("nope") + with patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, Mock())): + assert ( + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=WorkflowModelConfig(provider="openai", name="gpt", mode="chat"), + ) + is None + ) + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelMode", return_value=object()), + patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform"), + ): + with pytest.raises(ValueError, match="not support"): + retrieval._get_prompt_template( + model_config=ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ), + mode="chat", + metadata_fields=[], + query="q", + ) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py deleted file mode 100644 index 07d6e51e4b..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py +++ /dev/null @@ -1,873 +0,0 @@ -""" -Unit tests for DatasetRetrieval.process_metadata_filter_func. - -This module provides comprehensive test coverage for the process_metadata_filter_func -method in the DatasetRetrieval class, which is responsible for building SQLAlchemy -filter expressions based on metadata filtering conditions. - -Conditions Tested: -================== -1. **String Conditions**: contains, not contains, start with, end with -2. **Equality Conditions**: is / =, is not / ≠ -3. **Null Conditions**: empty, not empty -4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >= -5. **List Conditions**: in -6. **Edge Cases**: None values, different data types (str, int, float) - -Test Architecture: -================== -- Direct instantiation of DatasetRetrieval -- Mocking of DatasetDocument model attributes -- Verification of SQLAlchemy filter expressions -- Follows Arrange-Act-Assert (AAA) pattern - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\ -TestProcessMetadataFilterFunc::test_contains_condition -v -""" - -from unittest.mock import MagicMock - -import pytest - -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval - - -class TestProcessMetadataFilterFunc: - """ - Comprehensive test suite for process_metadata_filter_func method. - - This test class validates all metadata filtering conditions supported by - the DatasetRetrieval class, including string operations, numeric comparisons, - null checks, and list operations. - - Method Signature: - ================== - def process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list - ) -> list: - - The method builds SQLAlchemy filter expressions by: - 1. Validating value is not None (except for empty/not empty conditions) - 2. Using DatasetDocument.doc_metadata JSON field operations - 3. Adding appropriate SQLAlchemy expressions to the filters list - 4. Returning the updated filters list - - Mocking Strategy: - ================== - - Mock DatasetDocument.doc_metadata to avoid database dependencies - - Verify filter expressions are created correctly - - Test with various data types (str, int, float, list) - """ - - @pytest.fixture - def retrieval(self): - """ - Create a DatasetRetrieval instance for testing. - - Returns: - DatasetRetrieval: Instance to test process_metadata_filter_func - """ - return DatasetRetrieval() - - @pytest.fixture - def mock_doc_metadata(self): - """ - Mock the DatasetDocument.doc_metadata JSON field. - - The method uses DatasetDocument.doc_metadata[metadata_name] to access - JSON fields. We mock this to avoid database dependencies. - - Returns: - Mock: Mocked doc_metadata attribute - """ - mock_metadata_field = MagicMock() - - # Create mock for string access - mock_string_access = MagicMock() - mock_string_access.like = MagicMock() - mock_string_access.notlike = MagicMock() - mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_string_access.in_ = MagicMock(return_value=MagicMock()) - - # Create mock for float access (for numeric comparisons) - mock_float_access = MagicMock() - mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__le__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) - - # Create mock for null checks - mock_null_access = MagicMock() - mock_null_access.is_ = MagicMock(return_value=MagicMock()) - mock_null_access.isnot = MagicMock(return_value=MagicMock()) - - # Setup __getitem__ to return appropriate mock based on usage - def getitem_side_effect(name): - if name in ["author", "title", "category"]: - return mock_string_access - elif name in ["year", "price", "rating"]: - return mock_float_access - else: - return mock_string_access - - mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) - mock_metadata_field.as_string.return_value = mock_string_access - mock_metadata_field.as_float.return_value = mock_float_access - mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ - mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot - - return mock_metadata_field - - # ==================== String Condition Tests ==================== - - def test_contains_condition_string_value(self, retrieval): - """ - Test 'contains' condition with string value. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value% syntax - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "John" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_contains_condition(self, retrieval): - """ - Test 'not contains' condition. - - Verifies: - - Filters list is populated with NOT LIKE expression - - Pattern matching uses %value% syntax with negation - """ - filters = [] - sequence = 0 - condition = "not contains" - metadata_name = "title" - value = "banned" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_start_with_condition(self, retrieval): - """ - Test 'start with' condition. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses value% syntax - """ - filters = [] - sequence = 0 - condition = "start with" - metadata_name = "category" - value = "tech" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_end_with_condition(self, retrieval): - """ - Test 'end with' condition. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value syntax - """ - filters = [] - sequence = 0 - condition = "end with" - metadata_name = "filename" - value = ".pdf" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Equality Condition Tests ==================== - - def test_is_condition_with_string_value(self, retrieval): - """ - Test 'is' (=) condition with string value. - - Verifies: - - Filters list is populated with equality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = "Jane Doe" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_equals_condition_with_string_value(self, retrieval): - """ - Test '=' condition with string value. - - Verifies: - - Same behavior as 'is' condition - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "=" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_int_value(self, retrieval): - """ - Test 'is' condition with integer value. - - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_float_value(self, retrieval): - """ - Test 'is' condition with float value. - - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "price" - value = 19.99 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_string_value(self, retrieval): - """ - Test 'is not' (≠) condition with string value. - - Verifies: - - Filters list is populated with inequality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "author" - value = "Unknown" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_equals_condition(self, retrieval): - """ - Test '≠' condition with string value. - - Verifies: - - Same behavior as 'is not' condition - - Inequality expression is used - """ - filters = [] - sequence = 0 - condition = "≠" - metadata_name = "category" - value = "archived" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_numeric_value(self, retrieval): - """ - Test 'is not' condition with numeric value. - - Verifies: - - Numeric inequality comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Null Condition Tests ==================== - - def test_empty_condition(self, retrieval): - """ - Test 'empty' condition (null check). - - Verifies: - - Filters list is populated with IS NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "empty" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_empty_condition(self, retrieval): - """ - Test 'not empty' condition (not null check). - - Verifies: - - Filters list is populated with IS NOT NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "not empty" - metadata_name = "description" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Numeric Comparison Tests ==================== - - def test_before_condition(self, retrieval): - """ - Test 'before' (<) condition. - - Verifies: - - Filters list is populated with less than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "before" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_condition(self, retrieval): - """ - Test '<' condition. - - Verifies: - - Same behavior as 'before' condition - - Less than expression is used - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "price" - value = 100.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_after_condition(self, retrieval): - """ - Test 'after' (>) condition. - - Verifies: - - Filters list is populated with greater than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "after" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_condition(self, retrieval): - """ - Test '>' condition. - - Verifies: - - Same behavior as 'after' condition - - Greater than expression is used - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≤' condition. - - Verifies: - - Filters list is populated with less than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≤" - metadata_name = "price" - value = 50.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_ascii(self, retrieval): - """ - Test '<=' condition. - - Verifies: - - Same behavior as '≤' condition - - Less than or equal expression is used - """ - filters = [] - sequence = 0 - condition = "<=" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≥' condition. - - Verifies: - - Filters list is populated with greater than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≥" - metadata_name = "rating" - value = 3.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_ascii(self, retrieval): - """ - Test '>=' condition. - - Verifies: - - Same behavior as '≥' condition - - Greater than or equal expression is used - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== List/In Condition Tests ==================== - - def test_in_condition_with_comma_separated_string(self, retrieval): - """ - Test 'in' condition with comma-separated string value. - - Verifies: - - String is split into list - - Whitespace is trimmed from each value - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "tech, science, AI " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_list_value(self, retrieval): - """ - Test 'in' condition with list value. - - Verifies: - - List is processed correctly - - None values are filtered out - - IN expression is created with valid values - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "tags" - value = ["python", "javascript", None, "golang"] - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_tuple_value(self, retrieval): - """ - Test 'in' condition with tuple value. - - Verifies: - - Tuple is processed like a list - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = ("tech", "science", "ai") - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_empty_string(self, retrieval): - """ - Test 'in' condition with empty string value. - - Verifies: - - Empty string results in literal(False) filter - - No valid values to match - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - # Verify it's a literal(False) expression - # This is a bit tricky to test without access to the actual expression - - def test_in_condition_with_only_whitespace(self, retrieval): - """ - Test 'in' condition with whitespace-only string value. - - Verifies: - - Whitespace-only string results in literal(False) filter - - All values are stripped and filtered out - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = " , , " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_single_string(self, retrieval): - """ - Test 'in' condition with single non-comma string. - - Verifies: - - Single string is treated as single-item list - - IN expression is created with one value - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Edge Case Tests ==================== - - def test_none_value_with_non_empty_condition(self, retrieval): - """ - Test None value with conditions that require value. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values (except empty/not empty) - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 # No filter added - - def test_none_value_with_equals_condition(self, retrieval): - """ - Test None value with 'is' (=) condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_none_value_with_numeric_condition(self, retrieval): - """ - Test None value with numeric comparison condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "year" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_existing_filters_preserved(self, retrieval): - """ - Test that existing filters are preserved. - - Verifies: - - Existing filters in the list are not removed - - New filters are appended to the list - """ - existing_filter = MagicMock() - filters = [existing_filter] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 2 - assert filters[0] == existing_filter - - def test_multiple_filters_accumulated(self, retrieval): - """ - Test multiple calls to accumulate filters. - - Verifies: - - Each call adds a new filter to the list - - All filters are preserved across calls - """ - filters = [] - - # First filter - retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) - assert len(filters) == 1 - - # Second filter - retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) - assert len(filters) == 2 - - # Third filter - retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) - assert len(filters) == 3 - - def test_unknown_condition(self, retrieval): - """ - Test unknown/unsupported condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for unknown conditions - """ - filters = [] - sequence = 0 - condition = "unknown_condition" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_empty_string_value_with_contains(self, retrieval): - """ - Test empty string value with 'contains' condition. - - Verifies: - - Filter is added even with empty string - - LIKE expression is created - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_special_characters_in_value(self, retrieval): - """ - Test special characters in value string. - - Verifies: - - Special characters are handled in value - - LIKE expression is created correctly - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "title" - value = "C++ & Python's features" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_zero_value_with_numeric_condition(self, retrieval): - """ - Test zero value with numeric comparison condition. - - Verifies: - - Zero is treated as valid value - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "price" - value = 0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_negative_value_with_numeric_condition(self, retrieval): - """ - Test negative value with numeric comparison condition. - - Verifies: - - Negative numbers are handled correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "temperature" - value = -10.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_float_value_with_integer_comparison(self, retrieval): - """ - Test float value with numeric comparison condition. - - Verifies: - - Float values work correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py deleted file mode 100644 index 5f461d53ae..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py +++ /dev/null @@ -1,113 +0,0 @@ -import threading -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest -from flask import Flask, current_app - -from core.rag.models.document import Document -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from models.dataset import Dataset - - -class TestRetrievalService: - @pytest.fixture - def mock_dataset(self) -> Dataset: - dataset = Mock(spec=Dataset) - dataset.id = str(uuid4()) - dataset.tenant_id = str(uuid4()) - dataset.name = "test_dataset" - dataset.indexing_technique = "high_quality" - dataset.provider = "dify" - return dataset - - def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): - """ - Repro test for current bug: - reranking runs after `with flask_app.app_context():` exits. - `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, - so we must assert from that list (not from an outer try/except). - """ - dataset_retrieval = DatasetRetrieval() - flask_app = Flask(__name__) - tenant_id = str(uuid4()) - - # second dataset to ensure dataset_count > 1 reranking branch - secondary_dataset = Mock(spec=Dataset) - secondary_dataset.id = str(uuid4()) - secondary_dataset.provider = "dify" - secondary_dataset.indexing_technique = "high_quality" - - # retriever returns 1 doc into internal list (all_documents_item) - document = Document( - page_content="Context aware doc", - metadata={ - "doc_id": "doc1", - "score": 0.95, - "document_id": str(uuid4()), - "dataset_id": mock_dataset.id, - }, - provider="dify", - ) - - def fake_retriever( - flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids - ): - all_documents.append(document) - - called = {"init": 0, "invoke": 0} - - class ContextRequiredPostProcessor: - def __init__(self, *args, **kwargs): - called["init"] += 1 - # will raise RuntimeError if no Flask app context exists - _ = current_app.name - - def invoke(self, *args, **kwargs): - called["invoke"] += 1 - _ = current_app.name - return kwargs.get("documents") or args[1] - - # output list from _multiple_retrieve_thread - all_documents: list[Document] = [] - - # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here - thread_exceptions: list[Exception] = [] - - def target(): - with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): - with patch( - "core.rag.retrieval.dataset_retrieval.DataPostProcessor", - ContextRequiredPostProcessor, - ): - dataset_retrieval._multiple_retrieve_thread( - flask_app=flask_app, - available_datasets=[mock_dataset, secondary_dataset], - metadata_condition=None, - metadata_filter_document_ids=None, - all_documents=all_documents, - tenant_id=tenant_id, - reranking_enable=True, - reranking_mode="reranking_model", - reranking_model={ - "reranking_provider_name": "cohere", - "reranking_model_name": "rerank-v2", - }, - weights=None, - top_k=3, - score_threshold=0.0, - query="test query", - attachment_id=None, - dataset_count=2, # force reranking branch - thread_exceptions=thread_exceptions, # ✅ key - ) - - t = threading.Thread(target=target) - t.start() - t.join() - - # Ensure reranking branch was actually executed - assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." - - # Current buggy code should record an exception (not raise it) - assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py new file mode 100644 index 0000000000..cfa9094e12 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock + +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage + + +class TestFunctionCallMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = FunctionCallMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_model_response(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + usage = LLMUsage.empty_usage() + response = Mock() + response.usage = usage + response.message.tool_calls = [Mock(function=Mock())] + response.message.tool_calls[0].function.name = "dataset-2" + model_instance = Mock() + model_instance.invoke_llm.return_value = response + + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id == "dataset-2" + assert returned_usage == usage + model_instance.invoke_llm.assert_called_once() + + def test_invoke_returns_none_when_no_tool_calls(self) -> None: + router = FunctionCallMultiDatasetRouter() + response = Mock() + response.usage = LLMUsage.empty_usage() + response.message.tool_calls = [] + model_instance = Mock() + model_instance.invoke_llm.return_value = response + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == response.usage + + def test_invoke_returns_empty_usage_when_model_raises(self) -> None: + router = FunctionCallMultiDatasetRouter() + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("boom") + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py new file mode 100644 index 0000000000..e429563739 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -0,0 +1,252 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + + +class TestReactMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = ReactMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = ReactMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_react_invoke(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + tool_1 = Mock(name="dataset-1") + tool_1.name = "dataset-1" + tool_2 = Mock(name="dataset-2") + tool_2.name = "dataset-2" + + with patch.object(router, "_react_invoke", return_value=("dataset-2", usage)) as mock_react: + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + mock_react.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_invoke_handles_react_invoke_errors(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + with patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")): + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_react_invoke_returns_action_tool(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "chat" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] + tools[0].name = "dataset-1" + tools[0].description = "desc" + tools[1].name = "dataset-2" + tools[1].description = "desc" + + with ( + patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object(router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', usage)), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=tools, + user_id="u1", + tenant_id="t1", + ) + + mock_chat_prompt.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_react_invoke_returns_none_for_finish(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "completion" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', usage) + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert returned_usage == usage + + def test_invoke_llm_and_handle_result(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=usage) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([chunk]) + + with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + text, returned_usage = router._invoke_llm( + completion_param={"temperature": 0.1}, + model_instance=model_instance, + prompt_messages=[Mock()], + stop=["Observation:"], + user_id="u1", + tenant_id="t1", + ) + + assert text == "part" + assert returned_usage == usage + mock_deduct.assert_called_once() + + def test_handle_invoke_result_with_empty_usage(self) -> None: + router = ReactMultiDatasetRouter() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + + text, usage = router._handle_invoke_result(iter([chunk])) + + assert text == "part" + assert usage == LLMUsage.empty_usage() + + def test_create_chat_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + chat_prompt = router.create_chat_prompt(query="python", tools=[tool_1, tool_2]) + assert len(chat_prompt) == 2 + assert chat_prompt[0].role == PromptMessageRole.SYSTEM + assert chat_prompt[1].role == PromptMessageRole.USER + assert "dataset-1" in chat_prompt[0].text + assert "dataset-2" in chat_prompt[0].text + + def test_create_completion_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + completion_prompt = router.create_completion_prompt(tools=[tool_1, tool_2]) + assert "dataset-1: d1" in completion_prompt.text + assert "dataset-2: d2" in completion_prompt.text + + def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "unknown-mode" + model_config.parameters = {} + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, + "_invoke_llm", + return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()), + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + dataset_id, usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py new file mode 100644 index 0000000000..c8fa0ea62f --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py @@ -0,0 +1,69 @@ +import pytest + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser + + +class TestStructuredChatOutputParser: + def test_parse_action_without_action_input(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"some_action"}\n```' + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "some_action" + assert result.tool_input == {} + + def test_parse_json_without_action_key(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"not_action":"search"}\n```' + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) + + def test_parse_returns_action_for_tool_call(self) -> None: + parser = StructuredChatOutputParser() + text = ( + 'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```' + ) + + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "search_dataset" + assert result.tool_input == {"query": "python"} + assert result.log == text + + def test_parse_returns_finish_for_final_answer(self) -> None: + parser = StructuredChatOutputParser() + text = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```' + + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": "final text"} + assert result.log == text + + def test_parse_returns_finish_for_json_array_payload(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```' + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + assert result.log == text + + def test_parse_returns_finish_for_plain_text(self) -> None: + parser = StructuredChatOutputParser() + text = "No structured action block" + + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + + def test_parse_raises_value_error_for_invalid_json(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"search","action_input": }\n```' + + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 943a9e5712..976de10d89 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -125,7 +125,11 @@ Run with coverage: - Tests are organized by functionality in classes for better organization """ +import asyncio import string +import sys +import types +from inspect import currentframe from unittest.mock import Mock, patch import pytest @@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter: assert "def hello_world" in combined or "hello_world" in combined +class TestTextSplitterBasePaths: + """Target uncovered base TextSplitter paths.""" + + def test_from_huggingface_tokenizer_success_path(self): + """Cover from_huggingface_tokenizer success branch with mocked transformers.""" + + class _FakePreTrainedTokenizerBase: + pass + + class _FakeTokenizer(_FakePreTrainedTokenizerBase): + def encode(self, text: str): + return [ord(c) for c in text] + + fake_transformers = types.SimpleNamespace(PreTrainedTokenizerBase=_FakePreTrainedTokenizerBase) + with patch.dict(sys.modules, {"transformers": fake_transformers}): + splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + tokenizer=_FakeTokenizer(), + chunk_size=5, + chunk_overlap=1, + ) + + chunks = splitter.split_text("abcdef") + assert chunks + + def test_from_huggingface_tokenizer_import_error(self): + """Cover from_huggingface_tokenizer import-error branch.""" + with patch.dict(sys.modules, {"transformers": None}): + with pytest.raises(ValueError, match="Could not import transformers"): + RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5) + + def test_atransform_documents_raises_not_implemented(self): + """Cover atransform_documents NotImplemented branch.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5) + with pytest.raises(NotImplementedError): + asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})])) + + def test_merge_splits_logs_warning_for_oversized_total(self): + """Cover logger.warning path in _merge_splits.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1) + with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning: + merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1]) + assert merged + mock_warning.assert_called_once() + + # ============================================================================ # Test TokenTextSplitter # ============================================================================ @@ -662,6 +711,44 @@ class TestTokenTextSplitter: except ImportError: pytest.skip("tiktoken not installed") + def test_initialization_and_split_with_mocked_tiktoken_encoding(self): + """Cover TokenTextSplitter __init__ else-path and split_text logic.""" + + class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _FakeEncoding()) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1) + result = splitter.split_text("abcdefgh") + + assert result + assert all(isinstance(chunk, str) for chunk in result) + + def test_initialization_with_model_name_uses_encoding_for_model(self): + """Cover TokenTextSplitter model_name init branch.""" + + class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_encoding = _FakeEncoding() + fake_tiktoken = types.SimpleNamespace( + encoding_for_model=lambda model_name: fake_encoding, + get_encoding=lambda name: _FakeEncoding(), + ) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1) + + assert splitter._tokenizer is fake_encoding + # ============================================================================ # Test EnhanceRecursiveCharacterTextSplitter @@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter: assert len(result) > 0 assert all(isinstance(chunk, str) for chunk in result) + def test_from_encoder_internal_token_encoder_paths(self): + """ + Test internal _token_encoder branches by capturing local closure from frame. + + This validates: + - empty texts path + - embedding model path + - GPT2Tokenizer fallback path + - _character_encoder empty-path branch + """ + + class _SpySplitter(EnhanceRecursiveCharacterTextSplitter): + captured_token_encoder = None + captured_character_encoder = None + + def __init__(self, **kwargs): + frame = currentframe() + if frame and frame.f_back: + _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder") + _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder") + super().__init__(**kwargs) + + mock_model = Mock() + mock_model.get_text_embedding_num_tokens.return_value = [3, 5] + + _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1) + token_encoder = _SpySplitter.captured_token_encoder + character_encoder = _SpySplitter.captured_character_encoder + + assert token_encoder is not None + assert character_encoder is not None + assert token_encoder([]) == [] + assert token_encoder(["abc", "defgh"]) == [3, 5] + assert character_encoder([]) == [] + + with patch( + "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens", + side_effect=lambda text: len(text) + 1, + ): + _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1) + token_encoder_without_model = _SpySplitter.captured_token_encoder + assert token_encoder_without_model is not None + assert token_encoder_without_model(["ab", "cdef"]) == [3, 5] + # ============================================================================ # Test FixedRecursiveCharacterTextSplitter @@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter: chunks = splitter.split_text(data) assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + def test_recursive_split_keep_separator_and_recursive_fallback(self): + """Cover keep-separator split branch and recursive _split_text fallback.""" + text = "short." + ("x" * 60) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=[".", " ", ""], + chunk_size=10, + chunk_overlap=2, + keep_separator=True, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert any("short." in chunk for chunk in chunks) + assert any(len(chunk) <= 12 for chunk in chunks) + + def test_recursive_split_newline_separator_filtering(self): + """Cover newline-specific empty filtering branch.""" + text = "line1\n\nline2\n\nline3" + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n", ""], + chunk_size=50, + chunk_overlap=5, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert all(chunk != "" for chunk in chunks) + assert "line1" in "".join(chunks) + assert "line2" in "".join(chunks) + assert "line3" in "".join(chunks) + + def test_recursive_split_without_new_separator_appends_long_chunk(self): + """Cover branch where no further separators exist and long split is appended directly.""" + text = "aa\n" + ("b" * 40) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n"], + chunk_size=10, + chunk_overlap=2, + ) + + chunks = splitter.recursive_split_text(text) + + assert "aa" in chunks + assert any(len(chunk) >= 40 for chunk in chunks) + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 7bf7c2e5f6..9af4d12664 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock import pytest @@ -35,7 +34,7 @@ from models.human_input import ( def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + return HumanInputFormRepositoryImpl(tenant_id="tenant-id") def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: @@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession): return _factory +def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + """Patch repository's global session factory to return our fake session. + + The repositories under test now use a global session factory; patch its + create_session method so unit tests don't hit a real database. + """ + monkeypatch.setattr( + "core.repositories.human_input_repository.session_factory.create_session", + _session_factory(session), + raising=True, + ) + + class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): + def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: access_token="token-123", ) session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" - def test_get_form_returns_none_when_missing(self): + def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") assert repo.get_form("run-1", "node-1") is None - def test_get_form_returns_unsubmitted_state(self): + def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: expiration_time=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert entity.selected_action_id is None assert entity.submitted_data is None - def test_get_form_returns_submission_when_completed(self): + def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_at=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): + def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_token("token-123") @@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository: assert record.recipient_type == RecipientType.STANDALONE_WEB_APP assert record.submitted is False - def test_get_by_form_id_and_recipient_type_uses_recipient(self): + def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_form_id_and_recipient_type( form_id=form.id, @@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..c66e50437a --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,84 @@ +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType +from models import Account, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + + +def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: + engine = create_engine("sqlite:///:memory:") + real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) + + user = MagicMock(spec=Account) + user.id = str(uuid4()) + user.current_tenant_id = str(uuid4()) + + repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=real_session_factory, + user=user, + app_id="app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = False + repository._session_factory = MagicMock(return_value=session_context) + return repository + + +def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: + return WorkflowExecution.new( + id_=execution_id, + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0.0", + graph={"nodes": [], "edges": []}, + inputs={"query": "hello"}, + started_at=started_at, + ) + + +def test_save_uses_execution_started_at_when_record_does_not_exist(): + session = MagicMock() + session.get.return_value = None + repository = _build_repository_with_mocked_session(session) + + started_at = datetime(2026, 1, 1, 12, 0, 0) + execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + +def test_save_preserves_existing_created_at_when_record_already_exists(): + session = MagicMock() + repository = _build_repository_with_mocked_session(session) + + execution_id = str(uuid4()) + existing_created_at = datetime(2026, 1, 1, 12, 0, 0) + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repository._tenant_id + existing_run.created_at = existing_created_at + session.get.return_value = existing_run + + execution = _build_execution( + execution_id=execution_id, + started_at=datetime(2026, 1, 1, 12, 30, 0), + ) + + repository.save(execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/trigger/__init__.py b/api/tests/unit_tests/core/trigger/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/conftest.py b/api/tests/unit_tests/core/trigger/conftest.py new file mode 100644 index 0000000000..d9da80a8b7 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/conftest.py @@ -0,0 +1,93 @@ +"""Shared factory helpers for core.trigger test suite.""" + +from __future__ import annotations + +from typing import Any + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.entities.entities import ( + EventEntity, + EventIdentity, + EventParameter, + OAuthSchema, + Subscription, + SubscriptionConstructor, + TriggerProviderEntity, + TriggerProviderIdentity, +) +from core.trigger.provider import PluginTriggerProviderController +from models.provider_ids import TriggerProviderID + +# Valid format for TriggerProviderID: org/plugin/provider +VALID_PROVIDER_ID = "testorg/testplugin/testprovider" + + +def i18n(text: str = "test") -> I18nObject: + return I18nObject(en_US=text, zh_Hans=text) + + +def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity: + return EventEntity( + identity=EventIdentity(author="a", name=name, label=i18n(name)), + description=i18n(name), + parameters=parameters or [], + ) + + +def make_provider_entity( + name: str = "test_provider", + events: list[EventEntity] | None = None, + constructor: SubscriptionConstructor | None = None, + subscription_schema: list[ProviderConfig] | None = None, + icon: str | None = "icon.png", + icon_dark: str | None = None, +) -> TriggerProviderEntity: + return TriggerProviderEntity( + identity=TriggerProviderIdentity( + author="a", + name=name, + label=i18n(name), + description=i18n(name), + icon=icon, + icon_dark=icon_dark, + ), + events=events if events is not None else [make_event()], + subscription_constructor=constructor, + subscription_schema=subscription_schema or [], + ) + + +def make_controller( + entity: TriggerProviderEntity | None = None, + tenant_id: str = "tenant-1", + provider_id: str = VALID_PROVIDER_ID, +) -> PluginTriggerProviderController: + return PluginTriggerProviderController( + entity=entity or make_provider_entity(), + plugin_id="plugin-1", + plugin_unique_identifier="uid-1", + provider_id=TriggerProviderID(provider_id), + tenant_id=tenant_id, + ) + + +def make_subscription(**overrides: Any) -> Subscription: + defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}} + defaults.update(overrides) + return Subscription(**defaults) + + +def make_provider_config( + name: str = "api_key", required: bool = True, config_type: str = "secret-input" +) -> ProviderConfig: + return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required) + + +def make_constructor( + credentials_schema: list[ProviderConfig] | None = None, + oauth_schema: OAuthSchema | None = None, +) -> SubscriptionConstructor: + return SubscriptionConstructor( + parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema + ) diff --git a/api/tests/unit_tests/core/trigger/debug/__init__.py b/api/tests/unit_tests/core/trigger/debug/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py new file mode 100644 index 0000000000..d557c20f5e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py @@ -0,0 +1,93 @@ +""" +Tests for core.trigger.debug.event_bus.TriggerDebugEventBus. + +Covers: Lua-script dispatch/poll with Redis error resilience. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from redis import RedisError + +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.events import PluginTriggerDebugEvent + + +class TestDispatch: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_dispatch_count(self, mock_redis): + mock_redis.eval.return_value = 3 + event = MagicMock() + event.model_dump_json.return_value = '{"test": true}' + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 3 + mock_redis.eval.assert_called_once() + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_zero(self, mock_redis): + mock_redis.eval.side_effect = RedisError("connection lost") + event = MagicMock() + event.model_dump_json.return_value = "{}" + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 0 + + +class TestPoll: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_deserialized_event(self, mock_redis): + event_json = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ).model_dump_json() + mock_redis.eval.return_value = event_json + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is not None + assert result.name == "push" + + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_none_when_no_event(self, mock_redis): + mock_redis.eval.return_value = None + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_none(self, mock_redis): + mock_redis.eval.side_effect = RedisError("timeout") + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py new file mode 100644 index 0000000000..331bcd6c25 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -0,0 +1,276 @@ +""" +Tests for core.trigger.debug.event_selectors. + +Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory, +and select_trigger_debug_events orchestrator. +""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.debug.event_selectors import ( + PluginTriggerDebugEventPoller, + ScheduleTriggerDebugEventPoller, + WebhookTriggerDebugEventPoller, + create_event_poller, + select_trigger_debug_events, +) +from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent +from dify_graph.enums import NodeType +from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID + + +def _make_poller_args(node_config: dict | None = None) -> dict: + return { + "tenant_id": "t1", + "user_id": "u1", + "app_id": "a1", + "node_config": node_config or {"data": {}}, + "node_id": "n1", + } + + +def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict: + """Valid node config for TriggerEventNodeData.model_validate.""" + return { + "data": { + "title": "test", + "plugin_id": "org/testplugin", + "provider_id": provider_id, + "event_name": "push", + "subscription_id": "s1", + "plugin_unique_identifier": "uid-1", + } + } + + +class TestPluginTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_workflow_args_on_success(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={"repo": "dify"}, + cancelled=False, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"repo": "dify"} + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_invoke_cancelled(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={}, + cancelled=True, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + +class TestWebhookTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_uses_inputs_directly_when_present(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"inputs": {"key": "val"}, "webhook_data": {}}, + ) + mock_bus.poll.return_value = event + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"key": "val"} + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_falls_back_to_webhook_data(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"webhook_data": {"body": "raw"}}, + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc: + mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"} + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"parsed": "data"} + mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"}) + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + +class TestScheduleTriggerDebugEventPoller: + def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime): + """Set up mocks and create a schedule poller.""" + mock_redis.get.return_value = None + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + return ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + @patch("core.trigger.debug.event_selectors.redis_client") + @patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 12, 0, 0) + next_run = datetime(2025, 1, 1, 13, 0, 0) # future + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.redis_client") + @patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 14, 0, 0) + next_run = datetime(2025, 1, 1, 12, 0, 0) # past + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + mock_redis.delete.assert_called_once() + + +class TestCreateEventPoller: + def _workflow_with_node(self, node_type: NodeType): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = node_type + return wf + + def test_creates_plugin_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, PluginTriggerDebugEventPoller) + + def test_creates_webhook_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, WebhookTriggerDebugEventPoller) + + def test_creates_schedule_poller(self): + wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, ScheduleTriggerDebugEventPoller) + + def test_raises_for_unknown_type(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = NodeType.START + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + def test_raises_when_node_config_missing(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = None + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + +class TestSelectTriggerDebugEvents: + def test_returns_first_non_none_event(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll: + expected = MagicMock() + mock_poll.return_value = expected + + result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"]) + + assert result is expected + + def test_returns_none_when_no_events(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None): + result = select_trigger_debug_events(wf, app_model, "u1", ["n1"]) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/test_provider.py b/api/tests/unit_tests/core/trigger/test_provider.py new file mode 100644 index 0000000000..3c2f297e90 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_provider.py @@ -0,0 +1,332 @@ +""" +Tests for core.trigger.provider.PluginTriggerProviderController. + +Covers: to_api_entity creation-method logic, credential validation pipeline, +schema resolution by type, event lookup, dispatch/invoke/subscribe delegation. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities.entities import ( + EventParameter, + EventParameterType, + OAuthSchema, + TriggerCreationMethod, +) +from core.trigger.errors import TriggerProviderCredentialValidationError +from tests.unit_tests.core.trigger.conftest import ( + i18n, + make_constructor, + make_controller, + make_event, + make_provider_config, + make_provider_entity, + make_subscription, +) + +ICON_URL = "https://cdn/icon.png" + + +class TestToApiEntity: + @patch("core.trigger.provider.PluginService") + def test_includes_icons_when_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png")) + + api = ctrl.to_api_entity() + + assert api.icon == ICON_URL + assert api.icon_dark == ICON_URL + + @patch("core.trigger.provider.PluginService") + def test_icons_none_when_absent(self, mock_plugin_svc): + ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None)) + + api = ctrl.to_api_entity() + + assert api.icon is None + assert api.icon_dark is None + mock_plugin_svc.get_plugin_icon_url.assert_not_called() + + @patch("core.trigger.provider.PluginService") + def test_manual_only_without_schemas(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + api = ctrl.to_api_entity() + + assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL] + + @patch("core.trigger.provider.PluginService") + def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.OAUTH in api.supported_creation_methods + assert TriggerCreationMethod.MANUAL in api.supported_creation_methods + + @patch("core.trigger.provider.PluginService") + def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.APIKEY in api.supported_creation_methods + + +class TestGetEvent: + def test_returns_matching_event(self): + evt = make_event("push") + ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")])) + + assert ctrl.get_event("push") is evt + + def test_returns_none_for_unknown(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event("nonexistent") is None + + +class TestGetSubscriptionDefaultProperties: + def test_returns_defaults_skipping_none(self): + config1 = make_provider_config("key1") + config1.default = "val1" + config2 = make_provider_config("key2") + config2.default = None + ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2])) + + props = ctrl.get_subscription_default_properties() + + assert props == {"key1": "val1"} + + +class TestValidateCredentials: + def test_raises_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + with pytest.raises(ValueError, match="Subscription constructor not found"): + ctrl.validate_credentials("u1", {"key": "val"}) + + def test_raises_for_missing_required_field(self): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + + with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"): + ctrl.validate_credentials("u1", {}) + + @patch("core.trigger.provider.PluginTriggerClient") + def test_passes_with_valid_credentials(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = True + + ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise + + @patch("core.trigger.provider.PluginTriggerClient") + def test_raises_when_plugin_rejects(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = None + + with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"): + ctrl.validate_credentials("u1", {"api_key": "bad"}) + + +class TestGetSupportedCredentialTypes: + def test_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_supported_credential_types() == [] + + def test_oauth_only(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY not in types + + def test_apikey_only(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.API_KEY in types + assert CredentialType.OAUTH2 not in types + + def test_both(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")]) + ctrl = make_controller( + entity=make_provider_entity( + constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth) + ) + ) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY in types + + +class TestGetCredentialsSchema: + def test_returns_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_credentials_schema(CredentialType.API_KEY) == [] + + def test_returns_apikey_credentials(self): + cfg = make_provider_config("token") + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg]))) + + result = ctrl.get_credentials_schema(CredentialType.API_KEY) + + assert len(result) == 1 + assert result[0].name == "token" + + def test_returns_oauth_credentials(self): + oauth_cred = make_provider_config("oauth_token") + oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + result = ctrl.get_credentials_schema(CredentialType.OAUTH2) + + assert len(result) == 1 + assert result[0].name == "oauth_token" + + def test_unauthorized_returns_empty(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == [] + + def test_invalid_type_raises(self): + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor())) + with pytest.raises(ValueError, match="Invalid credential type"): + ctrl.get_credentials_schema("bogus_type") + + +class TestGetEventParameters: + def test_returns_params_for_known_event(self): + param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING) + evt = make_event("push", parameters=[param]) + ctrl = make_controller(entity=make_provider_entity(events=[evt])) + + result = ctrl.get_event_parameters("push") + + assert "branch" in result + assert result["branch"].name == "branch" + + def test_returns_empty_for_unknown_event(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event_parameters("nonexistent") == {} + + +class TestDispatch: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.dispatch_event.return_value = expected + + result = ctrl.dispatch( + request=MagicMock(), + subscription=make_subscription(), + credentials={"k": "v"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + mock_client.return_value.dispatch_event.assert_called_once() + + +class TestInvokeTriggerEvent: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.invoke_trigger_event.return_value = expected + + result = ctrl.invoke_trigger_event( + user_id="u1", + event_name="push", + parameters={}, + credentials={}, + credential_type=CredentialType.API_KEY, + subscription=make_subscription(), + request=MagicMock(), + payload={}, + ) + + assert result is expected + + +class TestSubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_subscription(self, mock_client): + ctrl = make_controller() + mock_client.return_value.subscribe.return_value.subscription = { + "expires_at": 123, + "endpoint": "https://e", + "properties": {}, + } + + result = ctrl.subscribe_trigger( + user_id="u1", + endpoint="https://e", + parameters={}, + credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.endpoint == "https://e" + + +class TestUnsubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_result(self, mock_client): + ctrl = make_controller() + mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"} + + result = ctrl.unsubscribe_trigger( + user_id="u1", + subscription=make_subscription(), + credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.success is True + + +class TestRefreshTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_uses_system_user_id(self, mock_client): + ctrl = make_controller() + mock_client.return_value.refresh.return_value.subscription = { + "expires_at": 456, + "endpoint": "https://e", + "properties": {}, + } + + ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY) + + call_kwargs = mock_client.return_value.refresh.call_args[1] + assert call_kwargs["user_id"] == "system" diff --git a/api/tests/unit_tests/core/trigger/test_trigger_manager.py b/api/tests/unit_tests/core/trigger/test_trigger_manager.py new file mode 100644 index 0000000000..612be25ec9 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_trigger_manager.py @@ -0,0 +1,307 @@ +""" +Tests for core.trigger.trigger_manager.TriggerManager. + +Covers: icon URL construction, provider listing with error resilience, +double-check lock caching, error translation, EventIgnoreError -> cancelled, +and delegation to provider controller. +""" + +from __future__ import annotations + +from threading import Lock +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError +from core.trigger.errors import EventIgnoreError +from core.trigger.trigger_manager import TriggerManager +from models.provider_ids import TriggerProviderID +from tests.unit_tests.core.trigger.conftest import ( + VALID_PROVIDER_ID, + make_controller, + make_provider_entity, + make_subscription, +) + +PID = TriggerProviderID(VALID_PROVIDER_ID) +PID_STR = str(PID) + + +class TestGetTriggerPluginIcon: + @patch("core.trigger.trigger_manager.dify_config") + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_builds_correct_url(self, mock_client, mock_config): + mock_config.CONSOLE_API_URL = "https://console.example.com" + provider = MagicMock() + provider.declaration.identity.icon = "my-icon.svg" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID) + + assert "tenant_id=tenant-1" in url + assert "filename=my-icon.svg" in url + assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon") + + +class TestListPluginTriggerProviders: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_wraps_entities_into_controllers(self, mock_client): + entity = MagicMock() + entity.declaration = make_provider_entity("p1") + entity.plugin_id = "plugin-1" + entity.plugin_unique_identifier = "uid-1" + entity.provider = VALID_PROVIDER_ID + mock_client.return_value.fetch_trigger_providers.return_value = [entity] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "plugin-1" + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_skips_failing_providers(self, mock_client): + good = MagicMock() + good.declaration = make_provider_entity("good") + good.plugin_id = "good-plugin" + good.plugin_unique_identifier = "uid-good" + good.provider = VALID_PROVIDER_ID + + bad = MagicMock() + bad.declaration = make_provider_entity("bad") + bad.plugin_id = "bad-plugin" + bad.plugin_unique_identifier = "uid-bad" + bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation + + mock_client.return_value.fetch_trigger_providers.return_value = [bad, good] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "good-plugin" + + +class TestGetTriggerProvider: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_initializes_context_on_first_call(self, mock_ctx, mock_client): + # get() called 3 times: (1) try block, (2) after set, (3) under lock + mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + mock_ctx.plugin_trigger_providers.set.assert_called_once_with({}) + mock_ctx.plugin_trigger_providers_lock.set.assert_called_once() + assert result is not None + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_returns_cached_without_fetch(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached} + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.side_effect = [ + {}, # first check misses + {PID_STR: cached}, # under-lock check hits + ] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client): + cache: dict = {} + mock_ctx.plugin_trigger_providers.get.return_value = cache + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is not None + assert PID_STR in cache + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_none_fetch_raises_value_error(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.return_value = None + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone") + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error") + + with pytest.raises(PluginDaemonError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + +class TestListAllTriggerProviders: + @patch.object(TriggerManager, "list_plugin_trigger_providers") + def test_delegates_to_list_plugin(self, mock_list): + expected = [make_controller()] + mock_list.return_value = expected + + assert TriggerManager.list_all_trigger_providers("t1") is expected + mock_list.assert_called_once_with("t1") + + +class TestListTriggersByProvider: + @patch.object(TriggerManager, "get_trigger_provider") + def test_returns_provider_events(self, mock_get): + ctrl = make_controller() + mock_get.return_value = ctrl + + result = TriggerManager.list_triggers_by_provider("t1", PID) + + assert result == ctrl.get_events() + + +class TestInvokeTriggerEvent: + def _args(self): + return { + "tenant_id": "t1", + "user_id": "u1", + "provider_id": PID, + "event_name": "on_push", + "parameters": {"branch": "main"}, + "credentials": {"token": "abc"}, + "credential_type": CredentialType.API_KEY, + "subscription": make_subscription(), + "request": MagicMock(), + "payload": {"action": "push"}, + } + + @patch.object(TriggerManager, "get_trigger_provider") + def test_returns_invoke_response(self, mock_get): + ctrl = MagicMock() + expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False) + ctrl.invoke_trigger_event.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result is expected + assert result.cancelled is False + + @patch.object(TriggerManager, "get_trigger_provider") + def test_event_ignore_returns_cancelled(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip") + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result.cancelled is True + assert result.variables == {} + + @patch.object(TriggerManager, "get_trigger_provider") + def test_other_errors_propagate(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = RuntimeError("boom") + mock_get.return_value = ctrl + + with pytest.raises(RuntimeError, match="boom"): + TriggerManager.invoke_trigger_event(**self._args()) + + +class TestSubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.subscribe_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.subscribe_trigger( + tenant_id="t1", + user_id="u1", + provider_id=PID, + endpoint="https://hook.test", + parameters={"f": "all"}, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + ctrl.subscribe_trigger.assert_called_once() + + +class TestUnsubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = MagicMock() + ctrl.unsubscribe_trigger.return_value = expected + mock_get.return_value = ctrl + sub = make_subscription() + + result = TriggerManager.unsubscribe_trigger( + tenant_id="t1", + user_id="u1", + provider_id=PID, + subscription=sub, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + + +class TestRefreshTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.refresh_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.refresh_trigger( + tenant_id="t1", + provider_id=PID, + subscription=make_subscription(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected diff --git a/api/tests/unit_tests/core/trigger/utils/__init__.py b/api/tests/unit_tests/core/trigger/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py new file mode 100644 index 0000000000..8804526e2e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py @@ -0,0 +1,62 @@ +"""Tests for core.trigger.utils.encryption — masking logic and cache key generation.""" + +from __future__ import annotations + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.utils.encryption import ( + TriggerProviderCredentialsCache, + TriggerProviderOAuthClientParamsCache, + TriggerProviderPropertiesCache, + masked_credentials, +) + + +def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig: + return ProviderConfig( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + type=field_type, + ) + + +class TestMaskedCredentials: + def test_short_secret_fully_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "ab"}) + assert result["key"] == "**" + + def test_long_secret_partially_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "abcdef"}) + assert result["key"].startswith("ab") + assert result["key"].endswith("ef") + assert "**" in result["key"] + + def test_non_secret_field_unchanged(self): + schema = [_make_schema("host", "text-input")] + result = masked_credentials(schema, {"host": "example.com"}) + assert result["host"] == "example.com" + + def test_unknown_key_passes_through(self): + result = masked_credentials([], {"unknown": "value"}) + assert result["unknown"] == "value" + + +class TestCacheKeyGeneration: + def test_credentials_cache_key_contains_ids(self): + cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "c1" in cache.cache_key + + def test_oauth_client_cache_key_contains_ids(self): + cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + + def test_properties_cache_key_contains_ids(self): + cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "s1" in cache.cache_key diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py new file mode 100644 index 0000000000..e5879aea0a --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py @@ -0,0 +1,31 @@ +"""Tests for core.trigger.utils.endpoint — URL generation.""" + +from __future__ import annotations + +from unittest.mock import patch + +from yarl import URL + +from core.trigger.utils import endpoint + + +class TestGeneratePluginTriggerEndpointUrl: + def test_builds_correct_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123") + + assert url == "https://api.example.com/triggers/plugin/endpoint-123" + + +class TestGenerateWebhookTriggerEndpoint: + def test_non_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False) + + assert url == "https://api.example.com/triggers/webhook/sub-456" + + def test_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True) + + assert url == "https://api.example.com/triggers/webhook-debug/sub-456" diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py new file mode 100644 index 0000000000..4fa202b164 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py @@ -0,0 +1,23 @@ +"""Tests for core.trigger.utils.locks — Redis lock key builders.""" + +from __future__ import annotations + +from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys + + +class TestBuildTriggerRefreshLockKey: + def test_correct_format(self): + key = build_trigger_refresh_lock_key("tenant-1", "sub-1") + + assert key == "trigger_provider_refresh_lock:tenant-1_sub-1" + + +class TestBuildTriggerRefreshLockKeys: + def test_maps_over_pairs(self): + pairs = [("t1", "s1"), ("t2", "s2")] + + keys = build_trigger_refresh_lock_keys(pairs) + + assert len(keys) == 2 + assert keys[0] == "trigger_provider_refresh_lock:t1_s1" + assert keys[1] == "trigger_provider_refresh_lock:t2_s2" diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 0df4927697..22792eb5b3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch import pytest +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.variables.variables import StringVariable class StubCoordinator: @@ -278,3 +280,17 @@ class TestGraphRuntimeState: assert restored_execution.started is True assert new_stub.state == "configured" + + def test_snapshot_restore_preserves_updated_conversation_variable(self): + variable_pool = VariablePool( + conversation_variables=[StringVariable(name="session_name", value="before")], + ) + variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + snapshot = state.dumps() + restored = GraphRuntimeState.from_snapshot(snapshot) + + restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) + assert restored_value is not None + assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index d2743cbbbe..b93f18c5bd 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -4,15 +4,13 @@ from typing import Any import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes import NodeType from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def _build_iteration_graph(node_id: str) -> dict[str, Any]: @@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]: def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + user_from="account", + invoke_from="debugger", call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index c9134c4543..b98d56147e 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,7 +6,6 @@ from dataclasses import dataclass import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from dify_graph.entities import GraphInitParams from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType from dify_graph.graph import Graph @@ -15,7 +14,7 @@ from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params class _TestNodeData(BaseNodeData): @@ -92,14 +91,14 @@ class _SimpleNodeFactory: @pytest.fixture def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + user_from="account", + invoke_from="service-api", call_depth=0, ) variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 352e270fe4..819fd67f9d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -32,6 +32,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -52,6 +53,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = NodeType.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -72,6 +74,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = NodeType.START node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node._model_instance = object() result_event = _build_succeeded_event() @@ -88,6 +91,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -109,6 +113,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = NodeType.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index 244bb2315d..f886ae1c2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -8,6 +8,7 @@ for workflows containing nodes that require third-party services. import pytest from dify_graph.enums import NodeType +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from .test_table_runner import TableTestRunner, WorkflowTestCase @@ -199,22 +200,19 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.entities import GraphInitParams + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom from .test_mock_factory import MockNodeFactory - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), @@ -309,11 +307,9 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.entities import GraphInitParams + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.nodes.template_transform import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom from .test_mock_factory import MockNodeFactory @@ -323,15 +319,14 @@ def test_register_custom_mock_node(): # Custom mock implementation pass - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 02ca827d6d..765c4deba3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,8 +3,8 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -20,7 +20,6 @@ from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, G from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.variables import IntegerVariable, StringVariable -from models.enums import UserFrom def test_abort_command(): @@ -41,13 +40,17 @@ def test_abort_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -151,13 +154,17 @@ def test_pause_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -207,13 +214,17 @@ def test_update_variables_command_updates_pool(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index 84e033156d..d54f0be190 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -21,6 +21,7 @@ from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -73,11 +74,11 @@ def _build_llm_node( def _build_graph(runtime_state: GraphRuntimeState) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 695e99c1cf..538f53c603 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -35,6 +34,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -47,11 +47,11 @@ def _build_branching_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 0275062c41..36bba6deb6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,7 +3,6 @@ import time from unittest import mock from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -34,6 +33,7 @@ from dify_graph.repositories.human_input_form_repository import HumanInputFormEn from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -46,11 +46,11 @@ def _build_llm_human_llm_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index fbcb8d7155..8da179c15e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,7 +1,6 @@ import time from unittest import mock -from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunStartedEvent, @@ -29,6 +28,7 @@ from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -37,15 +37,10 @@ from .test_table_runner import TableTestRunner, WorkflowTestCase def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + graph_init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 95bf4bff60..8c8e5977c8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,13 +2,7 @@ Simple test to verify MockNodeFactory works with iteration nodes. """ -import sys -from pathlib import Path - -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) - +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -16,20 +10,23 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom # Create a MockNodeFactory instance graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -65,10 +62,9 @@ def test_mock_factory_registers_iteration_node(): def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode # Create mock config @@ -76,13 +72,17 @@ def test_mock_iteration_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -127,10 +127,9 @@ def test_mock_iteration_node_preserves_config(): def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode # Create mock config @@ -138,13 +137,17 @@ def test_mock_loop_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index ad71227205..34e714a227 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -22,8 +22,13 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode from dify_graph.nodes.llm import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode +from dify_graph.nodes.template_transform.template_renderer import ( + Jinja2TemplateRenderer, + TemplateRenderError, +) from dify_graph.nodes.tool import ToolNode if TYPE_CHECKING: @@ -33,6 +38,18 @@ if TYPE_CHECKING: from .test_mock_config import MockConfig +class _TestJinja2Renderer(Jinja2TemplateRenderer): + """Simple Jinja2 renderer for tests (avoids code executor).""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + from jinja2 import Template as _Jinja2Template + + try: + return _Jinja2Template(template).render(**variables) + except Exception as exc: # pragma: no cover - pass through as contract error + raise TemplateRenderError(str(exc)) from exc + + class MockNodeMixin: """Mixin providing common mock functionality.""" @@ -49,6 +66,18 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # LLM-like nodes now require an http_client; provide a mock by default for tests. + kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) + + # Ensure TemplateTransformNode receives a renderer now required by constructor + if isinstance(self, TemplateTransformNode): + kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + + # Provide default tool_file_manager_factory for ToolNode subclasses + from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + + if isinstance(self, _ToolNode): + kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) super().__init__( id=id, @@ -583,13 +612,9 @@ class MockIterationNode(MockNodeMixin, IterationNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -659,13 +684,9 @@ class MockLoopNode(MockNodeMixin, LoopNode): # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 0942d15073..1550dca402 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,6 +6,7 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -44,13 +45,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -103,13 +108,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -163,13 +172,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -221,13 +234,17 @@ class TestMockTemplateTransformNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -286,13 +303,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -348,13 +369,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -418,13 +443,17 @@ class TestMockCodeNode: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -490,13 +519,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -531,13 +564,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -582,13 +619,17 @@ class TestMockNodeFactory: # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 975a48705a..693cdf9276 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -3,12 +3,8 @@ Simple test to validate the auto-mock system without external dependencies. """ import sys -from pathlib import Path - -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -101,21 +97,24 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom print("Testing MockNodeFactory detection...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -154,21 +153,24 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom print("Testing MockNodeFactory registration...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e269263bde..e681b39cc7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -32,6 +31,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params class PauseStateStore(Protocol): @@ -126,11 +126,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 910292b52c..60167c0441 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -39,6 +38,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -129,11 +129,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 1f456175e9..0ac9d6618d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -12,10 +12,9 @@ import time from unittest.mock import MagicMock, patch from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -30,7 +29,7 @@ from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -87,11 +86,11 @@ def test_parallel_streaming_workflow(): graph_config = workflow_config.get("graph", {}) # Create graph initialization parameters - init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + init_params = build_test_graph_init_params( workflow_id="test_workflow", graph_config=graph_config, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -100,8 +99,8 @@ def test_parallel_streaming_workflow(): # Create variable pool with system variables system_variables = SystemVariable( - user_id=init_params.user_id, - app_id=init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=init_params.workflow_id, files=[], query="Tell me about yourself", # User query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index e5a9a29a1f..7328ce443f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -40,6 +39,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -121,11 +121,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 183d589e2b..15a7de3c52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,7 +3,6 @@ import time from typing import Any from unittest.mock import MagicMock -from dify_graph.entities import GraphInitParams from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -30,6 +29,7 @@ from dify_graph.repositories.human_input_form_repository import ( from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: @@ -79,11 +79,11 @@ def _build_human_input_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ee36c976f0..767a8f60ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,19 +12,21 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, cast +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, @@ -48,6 +50,47 @@ from .test_mock_factory import MockNodeFactory logger = logging.getLogger(__name__) +class _TableTestChildEngineBuilder: + def __init__(self, *, use_mock_factory: bool, mock_config: MockConfig | None) -> None: + self._use_mock_factory = use_mock_factory + self._mock_config = mock_config + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + if self._use_mock_factory: + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=self._mock_config, + ) + else: + node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) + if not child_graph: + raise ValueError("child graph not found") + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + @dataclass class WorkflowTestCase: """Represents a single test case for table-driven testing.""" @@ -149,19 +192,23 @@ class WorkflowRunner: raise ValueError("Fixture missing workflow.graph configuration") graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config=graph_config, - user_id="test_user", - user_from="account", - invoke_from="debugger", # Set to debugger to avoid conversation_id requirement + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, # Set to debugger to avoid conversation_id requirement + } + }, call_depth=0, ) system_variables = SystemVariable( - user_id=graph_init_params.user_id, - app_id=graph_init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, @@ -315,6 +362,10 @@ class TableTestRunner: scale_up_threshold=self.graph_engine_scale_up_threshold, scale_down_idle_time=self.graph_engine_scale_down_idle_time, ), + child_engine_builder=_TableTestChildEngineBuilder( + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 842b9e2178..f0d80af1ed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,16 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.answer.answer_node import AnswerNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_answer(): @@ -36,11 +35,11 @@ def test_execute_answer(): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index b100c7c02c..db096b1aed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,3 +1,4 @@ +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.datasource.datasource_node import DatasourceNode @@ -28,13 +29,17 @@ class _GraphState: class _GraphParams: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 8ca973bf9b..5e34bf1d94 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -4,17 +4,16 @@ from typing import Any import httpx import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.entities import GraphInitParams from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=10, @@ -99,11 +98,11 @@ def _build_http_node( ], "edges": [], } - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 3b2a81ccef..55aa62a1c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -9,6 +9,7 @@ import pytest from pydantic import ValidationError from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.node_events import PauseRequestedEvent from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.human_input.entities import ( @@ -314,13 +315,17 @@ class TestHumanInputNodeVariableResolution: variable_pool.add(("start", "name"), "Jane Doe") runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -384,13 +389,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -439,13 +448,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user-123", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -550,13 +563,17 @@ class TestHumanInputNodeRenderedContent: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 1472cdec36..1fea19e795 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,8 +1,8 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.graph_init_params import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams from dify_graph.enums import NodeType from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, @@ -14,7 +14,6 @@ from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom class _FakeFormRepository: @@ -32,13 +31,17 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -92,13 +95,17 @@ def _build_timeout_node() -> HumanInputNode: start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py new file mode 100644 index 0000000000..2eb4feef5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +import pytest + +from dify_graph.entities import GraphInitParams +from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from dify_graph.runtime import ( + ChildEngineBuilderNotConfiguredError, + ChildGraphNotFoundError, + GraphRuntimeState, + VariablePool, +) +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +class _MissingGraphBuilder: + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> object: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + +def _build_runtime_state() -> GraphRuntimeState: + return GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + start_at=0.0, + ) + + +def _build_iteration_node( + *, + graph_config: Mapping[str, Any], + runtime_state: GraphRuntimeState, + start_node_id: str, +) -> IterationNode: + init_params = build_test_graph_init_params(graph_config=graph_config) + return IterationNode( + id="iteration-node", + config={ + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration-node", "output"], + "start_node_id": start_node_id, + }, + }, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + +def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing(): + runtime_state = _build_runtime_state() + graph_init_params = build_test_graph_init_params() + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + runtime_state.create_child_engine( + workflow_id="workflow", + graph_init_params=graph_init_params, + graph_runtime_state=_build_runtime_state(), + graph_config={}, + root_node_id="root", + ) + + +def test_iteration_node_only_translates_child_graph_not_found_error(): + runtime_state = _build_runtime_state() + runtime_state.bind_child_engine_builder(_MissingGraphBuilder()) + node = _build_iteration_node( + graph_config={"nodes": [{"id": "present-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="missing-node", + ) + + with pytest.raises(IterationGraphNotFoundError): + node._create_graph_engine(index=0, item="item") + + +def test_iteration_node_propagates_non_graph_not_found_errors(): + runtime_state = _build_runtime_state() + node = _build_iteration_node( + graph_config={"nodes": [{"id": "start-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="start-node", + ) + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + node._create_graph_engine(index=0, item="item") diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 83a5db2fae..8116fc8b3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -4,8 +4,7 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus from dify_graph.nodes.knowledge_index.entities import KnowledgeIndexNodeData from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError @@ -15,17 +14,17 @@ from dify_graph.repositories.summary_index_service_protocol import SummaryIndexS from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 98246ccb2f..e929d652fd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -4,12 +4,13 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities import GraphInitParams +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.nodes.knowledge_retrieval.entities import ( + Condition, KnowledgeRetrievalNodeData, + MetadataFilteringCondition, MultipleRetrievalConfig, RerankingModelConfig, SingleRetrievalConfig, @@ -20,17 +21,17 @@ from dify_graph.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -111,7 +112,6 @@ class TestKnowledgeRetrievalNode: # Assert assert node.id == node_id assert node._rag_retrieval == mock_rag_retrieval - assert node._llm_file_saver is not None def test_run_with_no_query_or_attachment( self, @@ -206,6 +206,7 @@ class TestKnowledgeRetrievalNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert "result" in result.outputs assert mock_rag_retrieval.knowledge_retrieval.called + mock_source.model_dump.assert_called_once_with(by_alias=True) def test_run_with_query_variable_multiple_mode( self, @@ -593,3 +594,106 @@ class TestFetchDatasetRetriever: # Assert assert version == "1" + + def test_resolve_metadata_filtering_conditions_templates( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": { + "title": "Knowledge Retrieval", + "type": "knowledge-retrieval", + "dataset_ids": [str(uuid.uuid4())], + "retrieval_mode": "multiple", + }, + } + # Variable in pool used by template + mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + conditions = MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"), + Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]), + Condition(name="year", comparison_operator="=", value=2025), + ], + ) + + # Act + resolved = node._resolve_metadata_filtering_conditions(conditions) + + # Assert + assert resolved.logical_operator == "and" + assert resolved.conditions[0].value == "readme" + assert isinstance(resolved.conditions[1].value, list) + assert resolved.conditions[1].value[1] == "readme" + assert resolved.conditions[2].value == 2025 + + def test_fetch_passes_resolved_metadata_conditions( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_fetch_dataset_retriever should pass resolved metadata conditions into request.""" + # Arrange + query = "hi" + variables = {"query": query} + mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme")) + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=4, + score_threshold=0.0, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"), + ), + metadata_filtering_mode="manual", + metadata_filtering_conditions=MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"), + ], + ), + ) + + node_id = str(uuid.uuid4()) + config = {"id": node_id, "data": node_data.model_dump()} + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + mock_rag_retrieval.knowledge_retrieval.return_value = [] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + # Act + node._fetch_dataset_retriever(node_data=node_data, variables=variables) + + # Assert the passed request has resolved value + call_args = mock_rag_retrieval.knowledge_retrieval.call_args + request = call_args[1]["request"] + assert request.metadata_filtering_conditions is not None + assert request.metadata_filtering_conditions.conditions[0].value == "readme" diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 65228df517..25760ba352 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.entities.graph import Graph -from dify_graph.graph_engine.entities.graph_init_params import GraphInitParams -from dify_graph.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode +from dify_graph.runtime import GraphRuntimeState from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment -from models.workflow import WorkflowType class TestListOperatorNode: @@ -22,43 +21,40 @@ class TestListOperatorNode: mock_state.variable_pool = mock_variable_pool return mock_state - @pytest.fixture - def mock_graph(self): - """Create mock Graph.""" - return MagicMock(spec=Graph) - @pytest.fixture def graph_init_params(self): """Create GraphInitParams fixture.""" return GraphInitParams( - tenant_id="test", - app_id="test", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test", graph_config={}, - user_id="test", - user_from="test", - invoke_from="test", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": "test", + "invoke_from": "test", + } + }, call_depth=0, ) @pytest.fixture - def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state): + def list_operator_node_factory(self, graph_init_params, mock_graph_runtime_state): """Factory fixture for creating ListOperatorNode instances.""" def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) return _create_node - def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, mock_graph_runtime_state, graph_init_params): """Test node initializes correctly.""" config = { "title": "List Operator", @@ -70,9 +66,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -101,7 +96,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_array(self, mock_graph_runtime_state, graph_init_params): """Test with empty array.""" config = { "title": "Test", @@ -116,9 +111,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -129,7 +123,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] is None assert result.outputs["last_record"] is None - def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with contains condition.""" config = { "title": "Test", @@ -148,9 +142,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -159,7 +152,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple"] - def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_not_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with not contains condition.""" config = { "title": "Test", @@ -178,9 +171,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -189,7 +181,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["banana", "cherry"] - def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_greater_than(self, mock_graph_runtime_state, graph_init_params): """Test filter with greater than condition on numbers.""" config = { "title": "Test", @@ -208,9 +200,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -219,7 +210,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [7, 9, 11] - def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in ascending order.""" config = { "title": "Test", @@ -237,9 +228,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -248,7 +238,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_descending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in descending order.""" config = { "title": "Test", @@ -266,9 +256,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -277,7 +266,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["cherry", "banana", "apple"] - def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_limit(self, mock_graph_runtime_state, graph_init_params): """Test with limit enabled.""" config = { "title": "Test", @@ -295,9 +284,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -306,7 +294,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana"] - def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_order_and_limit(self, mock_graph_runtime_state, graph_init_params): """Test with filter, order, and limit combined.""" config = { "title": "Test", @@ -331,9 +319,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -342,7 +329,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [9, 8, 7] - def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_variable_not_found(self, mock_graph_runtime_state, graph_init_params): """Test when variable is not found.""" config = { "title": "Test", @@ -356,9 +343,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -367,7 +353,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Variable not found" in result.error - def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_first_and_last_record(self, mock_graph_runtime_state, graph_init_params): """Test first_record and last_record outputs.""" config = { "title": "Test", @@ -382,9 +368,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -394,7 +379,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] == "first" assert result.outputs["last_record"] == "last" - def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_startswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with startswith condition.""" config = { "title": "Test", @@ -413,9 +398,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -424,7 +408,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "application"] - def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_endswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with endswith condition.""" config = { "title": "Test", @@ -443,9 +427,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -454,7 +437,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple", "table"] - def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with equals condition.""" config = { "title": "Test", @@ -473,9 +456,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -484,7 +466,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [5, 5] - def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_not_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with not equals condition.""" config = { "title": "Test", @@ -503,9 +485,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -514,7 +495,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [1, 3, 7, 9] - def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test number ordering in ascending order.""" config = { "title": "Test", @@ -532,9 +513,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index a3afd1ed5c..b0f0fd428b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -1,10 +1,10 @@ import uuid from typing import NamedTuple from unittest import mock +from unittest.mock import MagicMock import httpx import pytest -from sqlalchemy import Engine from core.helper import ssrf_proxy from core.tools import signature @@ -44,7 +44,6 @@ class TestFileSaverImpl: ) mock_tool_file.id = _gen_id() mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - mocked_engine = mock.MagicMock(spec=Engine) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) @@ -53,11 +52,12 @@ class TestFileSaverImpl: # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) mocked_sign_file.return_value = mock_signed_url + http_client = MagicMock() storage_file_manager = FileSaverImpl( user_id=user_id, tenant_id=tenant_id, - engine_factory=mocked_engine, + http_client=http_client, ) file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) @@ -87,16 +87,18 @@ class TestFileSaverImpl: status_code=401, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response + file_saver = FileSaverImpl( user_id=_gen_id(), tenant_id=_gen_id(), + http_client=http_client, ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) with pytest.raises(httpx.HTTPStatusError) as exc: file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_get.assert_called_once_with(_TEST_URL) + http_client.get.assert_called_once_with(_TEST_URL) assert exc.value.response.status_code == 401 def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): @@ -112,8 +114,10 @@ class TestFileSaverImpl: headers={"Content-Type": mime_type}, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) mock_tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 18ec3c0dc4..d56035b6bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,7 +5,7 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration @@ -39,8 +39,8 @@ from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment -from models.enums import UserFrom from models.provider import ProviderType +from tests.workflow_test_utils import build_test_graph_init_params class MockTokenBufferMemory: @@ -76,11 +76,11 @@ def llm_node_data() -> LLMNodeData: @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="1", - app_id="1", + return build_test_graph_init_params( workflow_id="1", graph_config={}, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, @@ -111,6 +111,7 @@ def llm_node( "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -120,6 +121,7 @@ def llm_node( model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node @@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node, mock_file_saver diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 7edc56c3dc..6831626f58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,14 +1,14 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.entities.graph import Graph -from dify_graph.graph_engine.entities.graph_init_params import GraphInitParams -from dify_graph.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowType +from dify_graph.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params class TestTemplateTransformNode: @@ -24,21 +24,20 @@ class TestTemplateTransformNode: @pytest.fixture def mock_graph(self): - """Create a mock Graph.""" + """Create a mock Graph (kept for backward compat in other tests).""" return MagicMock(spec=Graph) @pytest.fixture def graph_init_params(self): """Create a mock GraphInitParams.""" - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", - workflow_type=WorkflowType.WORKFLOW, + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", - user_from="test", - invoke_from="test", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, call_depth=0, ) @@ -55,14 +54,15 @@ class TestTemplateTransformNode: "template": "Hello {{ name }}, you are {{ age }} years old!", } - def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node.node_type == NodeType.TEMPLATE_TRANSFORM @@ -70,31 +70,33 @@ class TestTemplateTransformNode: assert len(node._node_data.variables) == 2 assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" - def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" - def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_error_strategy(self, mock_graph_runtime_state, graph_init_params): """Test _get_error_strategy method.""" node_data = { "title": "Test", @@ -103,12 +105,13 @@ class TestTemplateTransformNode: "error_strategy": "fail-branch", } + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -127,14 +130,8 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_simple_template( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run with simple template transformation.""" + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool mock_name_value = MagicMock() mock_name_value.to_object.return_value = "Alice" @@ -147,15 +144,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - # Setup mock executor - mock_execute.return_value = "Hello Alice, you are 30 years old!" + # Setup mock renderer + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -165,11 +163,7 @@ class TestTemplateTransformNode: assert result.inputs["name"] == "Alice" assert result.inputs["age"] == 30 - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_none_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" node_data = { "title": "Test", @@ -178,14 +172,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.return_value = None - mock_execute.return_value = "Value: " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Value: " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -193,23 +189,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs["value"] is None - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_code_execution_error( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run when code execution fails.""" + def test_run_with_render_error(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run when template rendering fails.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.side_effect = TemplateRenderError("Template syntax error") + + mock_renderer = MagicMock() + mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -217,23 +209,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Template syntax error" in result.error - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_output_length_exceeds_limit( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_output_length_exceeds_limit(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run when output exceeds maximum length.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.return_value = "This is a very long output that exceeds the limit" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, max_output_length=10, ) @@ -242,13 +230,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_complex_jinja2_template( - self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { "title": "Complex Template", @@ -272,14 +254,16 @@ class TestTemplateTransformNode: ("sys", "show_total"): mock_show_total, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "apple, banana, orange (Total: 3)" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -307,11 +291,7 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { "title": "Static Template", @@ -319,14 +299,15 @@ class TestTemplateTransformNode: "template": "This is a static message.", } - mock_execute.return_value = "This is a static message." + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a static message." node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -335,11 +316,7 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "This is a static message." assert result.inputs == {} - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_numeric_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" node_data = { "title": "Numeric Template", @@ -360,14 +337,16 @@ class TestTemplateTransformNode: ("sys", "quantity"): mock_quantity, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "Total: $31.5" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Total: $31.5" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -375,11 +354,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["output"] == "Total: $31.5" - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_dict_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" node_data = { "title": "Dict Template", @@ -391,14 +366,16 @@ class TestTemplateTransformNode: mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"} mock_graph_runtime_state.variable_pool.get.return_value = mock_user - mock_execute.return_value = "Name: John Doe, Email: john@example.com" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -407,11 +384,7 @@ class TestTemplateTransformNode: assert "John Doe" in result.outputs["output"] assert "john@example.com" in result.outputs["output"] - @patch( - "dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_list_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" node_data = { "title": "List Template", @@ -423,14 +396,16 @@ class TestTemplateTransformNode: mock_tags.to_object.return_value = ["python", "ai", "workflow"] mock_graph_runtime_state.variable_pool.get.return_value = mock_tags - mock_execute.return_value = "Tags: #python #ai #workflow " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index d5e0a3ce2e..44abf430c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -2,12 +2,14 @@ from collections.abc import Mapping import pytest +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SampleNodeData(BaseNodeData): @@ -26,15 +28,10 @@ class _SampleNode(Node[_SampleNodeData]): def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -56,6 +53,36 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == "account" + assert dify_ctx.invoke_from == "debugger" + + +def test_node_accepts_invoke_from_enum(): + graph_config: dict[str, object] = {} + init_params = build_test_graph_init_params( + graph_config=graph_config, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + start_at=0.0, + ) + + node = _SampleNode( + id="node-1", + config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == UserFrom.ACCOUNT + assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER + assert node.get_run_context_value("missing") is None + with pytest.raises(ValueError): + node.require_run_context_value("missing") def test_missing_generic_argument_raises_type_error(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index b6169cd735..13275d4be6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -5,7 +5,7 @@ import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod @@ -16,20 +16,21 @@ from dify_graph.nodes.document_extractor.node import ( _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, + _normalize_docx_zip, ) from dify_graph.variables import ArrayFileSegment from dify_graph.variables.segments import ArrayStringSegment from dify_graph.variables.variables import StringVariable -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -44,11 +45,13 @@ def document_extractor_node(graph_init_params): variable_selector=["node_id", "variable_name"], ) node_config = {"id": "test_node_id", "data": node_data.model_dump()} + http_client = Mock() node = DocumentExtractorNode( id="test_node_id", config=node_config, graph_init_params=graph_init_params, graph_runtime_state=Mock(), + http_client=http_client, ) return node @@ -84,6 +87,38 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s assert "is not an ArrayFileSegment" in result.error +def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """Empty file list should return SUCCEEDED with empty documents and ArrayStringSegment([]).""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Provide an actual ArrayFileSegment with an empty list + mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[]) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + +def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """A file list containing only None (e.g., [None]) should be filtered to [] and succeed.""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Use a Mock to bypass type validation for None entries in the list + afs = Mock(spec=ArrayFileSegment) + afs.value = [None] + mock_graph_runtime_state.variable_pool.get.return_value = afs + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ @@ -142,12 +177,13 @@ def test_run_extract_text( mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment mock_download = Mock(return_value=file_content) - mock_ssrf_proxy_get = Mock() - mock_ssrf_proxy_get.return_value.content = file_content - mock_ssrf_proxy_get.return_value.raise_for_status = Mock() + + mock_response = Mock() + mock_response.content = file_content + mock_response.raise_for_status = Mock() + document_extractor_node._http_client.get = Mock(return_value=mock_response) monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) - monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) @@ -164,7 +200,7 @@ def test_run_extract_text( assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: - mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: mock_download.assert_called_once_with(mock_file) @@ -382,3 +418,58 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" assert expected_manual == result + + +def _make_docx_zip(use_backslash: bool) -> bytes: + """Helper to build a minimal in-memory DOCX zip. + + When use_backslash=True the ZIP entry names use backslash separators + (as produced by Evernote on Windows), otherwise forward slashes are used. + """ + import zipfile + + sep = "\\" if use_backslash else "/" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr("[Content_Types].xml", b"
{userProfile.name}
diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx
index 8ea29e8e45..07b685b8c5 100644
--- a/web/app/account/(commonLayout)/avatar.tsx
+++ b/web/app/account/(commonLayout)/avatar.tsx
@@ -7,12 +7,11 @@ import { useRouter } from 'next/navigation'
import { Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
-import Avatar from '@/app/components/base/avatar'
+import { Avatar } from '@/app/components/base/avatar'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
-import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
-import { useLogout } from '@/service/use-common'
+import { useLogout, useUserProfile } from '@/service/use-common'
export type IAppSelector = {
isMobile: boolean
@@ -21,10 +20,15 @@ export type IAppSelector = {
export default function AppSelector() {
const router = useRouter()
const { t } = useTranslation()
- const { userProfile } = useAppContext()
+ const { data: userProfileResp } = useUserProfile()
+ const userProfile = userProfileResp?.profile
const { isEducationAccount } = useProviderContext()
const { mutateAsync: logout } = useLogout()
+
+ if (!userProfile)
+ return null
+
const handleLogout = async () => {
await logout()
@@ -50,7 +54,7 @@ export default function AppSelector() {
${open && 'bg-components-panel-bg-blur'}
`}
>
-