From eef13853b22d9fc76040f391c3d0ca422313cc45 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 25 Mar 2026 10:21:57 +0800 Subject: [PATCH 1/8] fix(api): StreamsBroadcastChannel start reading messages from the end (#34030) The current frontend implementation closes the connection once a `workflow_paused` SSE event is received and establishes a new connection to subscribe to new events. The implementation of `StreamsBroadcastChannel` sets the initial `_last_id` to `0-0`, so each new subscription consumes the stream from the start and resends previously published `workflow_paused` events to the frontend, causing excessive connections to be established. This PR fixes the issue by setting the initial id to `$`, which means only new messages are received by the subscription. --- api/libs/broadcast_channel/channel.py | 3 +- .../redis/streams_channel.py | 5 +- .../redis/test_streams_channel.py | 227 +++++++++++++++++ .../redis/test_streams_channel_unit_tests.py | 228 +++++++++++++++++- 4 files changed, 451 insertions(+), 12 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index d4cb3e9971..8eeac37232 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -125,7 +125,8 @@ class BroadcastChannel(Protocol): a specific topic, all subscription should receive the published message. There are no restriction for the persistence of messages. Once a subscription is created, it - should receive all subsequent messages published. + should receive all subsequent messages published. However, a subscription should not receive + any message published before the subscription is established. `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. 
""" diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..aaeaf76f7b 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -64,7 +64,10 @@ class _StreamsSubscription(Subscription): self._client = client self._key = key self._closed = threading.Event() - self._last_id = "0-0" + # Setting initial last id to `$` to signal redis that we only want new messages. + # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + self._last_id = "$" self._queue: queue.Queue[object] = queue.Queue() self._start_lock = threading.Lock() self._listener: threading.Thread | None = None diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 0000000000..a79208f649 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. 
+""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = 
subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert 
subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + 
subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def 
test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 248aa0b145..bf548f69cf 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -1,7 +1,11 @@ +import threading import time +from dataclasses import dataclass +from typing import cast import pytest +from libs.broadcast_channel.exc import SubscriptionClosedError from libs.broadcast_channel.redis.streams_channel import ( StreamsBroadcastChannel, StreamsTopic, @@ -22,6 +26,7 @@ class FakeStreamsRedis: self._store: dict[str, list[tuple[str, dict]]] = {} self._next_id: dict[str, int] = {} self._expire_calls: dict[str, int] = {} + self._dollar_snapshots: dict[str, int] = {} # Publisher API def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: @@ -47,7 +52,9 @@ class FakeStreamsRedis: # Find position strictly greater than last_id start_idx = 0 - if last_id != "0-0": + if last_id == "$": + start_idx = self._dollar_snapshots.setdefault(key, len(entries)) + elif last_id != "0-0": for i, (eid, _f) in enumerate(entries): if eid == last_id: start_idx = i + 1 @@ -63,6 +70,55 @@ class FakeStreamsRedis: return [(key, batch)] +class FailExpireRedis(FakeStreamsRedis): + def expire(self, key: str, seconds: int) -> None: + raise RuntimeError("expire failed") + + +class BlockingRedis: + def __init__(self) -> 
None: + self._release = threading.Event() + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._release.wait(timeout=block / 1000.0 if block else None) + return [] + + def release(self) -> None: + self._release.set() + + +@dataclass(frozen=True) +class ListenPayloadCase: + name: str + fields: object + expected_messages: list[bytes] + + +def build_listen_payload_cases() -> list[ListenPayloadCase]: + return [ + ListenPayloadCase( + name="string_payload_is_encoded", + fields={b"data": "hello"}, + expected_messages=[b"hello"], + ), + ListenPayloadCase( + name="bytearray_payload_is_converted", + fields={b"data": bytearray(b"world")}, + expected_messages=[b"world"], + ), + ListenPayloadCase( + name="non_dict_fields_are_ignored", + fields=[("data", b"ignored")], + expected_messages=[], + ), + ListenPayloadCase( + name="missing_payload_is_ignored", + fields={b"other": b"ignored"}, + expected_messages=[], + ), + ] + + @pytest.fixture def fake_redis() -> FakeStreamsRedis: return FakeStreamsRedis() @@ -94,21 +150,37 @@ class TestStreamsBroadcastChannel: # Expire called after publish assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("producer-subscriber") + + assert topic.as_producer() is topic + assert topic.as_subscriber() is topic + + def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture): + channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60) + topic = channel.topic("expire-warning") + + topic.publish(b"payload") + + assert "Failed to set expire for stream key" in caplog.text + class TestStreamsSubscription: - def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + def test_subscribe_only_receives_messages_published_after_subscription_starts( + self, + streams_channel: StreamsBroadcastChannel, 
+ ): topic = streams_channel.topic("gamma") - # Pre-publish events before subscribing (late subscriber) - topic.publish(b"e1") - topic.publish(b"e2") + topic.publish(b"before-subscribe") sub = topic.subscribe() assert isinstance(sub, _StreamsSubscription) received: list[bytes] = [] with sub: - # Give listener thread a moment to xread - time.sleep(0.05) + assert sub.receive(timeout=0.05) is None + topic.publish(b"after-subscribe-1") + topic.publish(b"after-subscribe-2") # Drain using receive() to avoid indefinite iteration in tests for _ in range(5): msg = sub.receive(timeout=0.1) @@ -116,7 +188,7 @@ class TestStreamsSubscription: break received.append(msg) - assert received == [b"e1", b"e2"] + assert received == [b"after-subscribe-1", b"after-subscribe-2"] def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("delta") @@ -132,8 +204,6 @@ class TestStreamsSubscription: # Listener running; now close and ensure no crash sub.close() # After close, receive should raise SubscriptionClosedError - from libs.broadcast_channel.exc import SubscriptionClosedError - with pytest.raises(SubscriptionClosedError): sub.receive() @@ -143,3 +213,141 @@ class TestStreamsSubscription: topic.publish(b"payload") # No expire recorded when retention is disabled assert fake_redis._expire_calls.get("stream:zeta") is None + + @pytest.mark.parametrize( + ("case"), + build_listen_payload_cases(), + ids=lambda case: cast(ListenPayloadCase, case).name, + ) + def test_listener_normalizes_supported_payloads_and_ignores_unsupported_shapes(self, case: ListenPayloadCase): + class OneShotRedis: + def __init__(self, fields: object) -> None: + self._fields = fields + self._calls = 0 + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + self._calls += 1 + if self._calls == 1: + key = next(iter(streams)) + return [(key, [("1-0", self._fields)])] + subscription._closed.set() + return [] + + subscription = 
_StreamsSubscription(OneShotRedis(case.fields), "stream:payload-shape") + subscription._listen() + + received: list[bytes] = [] + while not subscription._queue.empty(): + item = subscription._queue.get_nowait() + if item is subscription._SENTINEL: + break + received.append(bytes(item)) + + assert received == case.expected_messages + assert subscription._last_id == "1-0" + + def test_iterator_yields_messages_until_subscription_is_closed(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("iter") + subscription = topic.subscribe() + iterator = iter(subscription) + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"iter-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + assert next(iterator) == b"iter-message" + + subscription.close() + publisher.join(timeout=1) + with pytest.raises(StopIteration): + next(iterator) + + def test_receive_with_none_timeout_blocks_until_message_arrives(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("blocking") + subscription = topic.subscribe() + + def publish_later() -> None: + time.sleep(0.05) + topic.publish(b"blocking-message") + + publisher = threading.Thread(target=publish_later, daemon=True) + publisher.start() + + try: + assert subscription.receive(timeout=None) == b"blocking-message" + finally: + subscription.close() + publisher.join(timeout=1) + + def test_receive_raises_when_queue_contains_close_sentinel(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:sentinel") + subscription._listener = threading.current_thread() + subscription._queue.put_nowait(subscription._SENTINEL) + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_close_before_listener_starts_is_a_noop(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:not-started") + + subscription.close() + + assert subscription._listener is None + with 
pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.01) + + def test_start_if_needed_returns_immediately_for_closed_subscription(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:already-closed") + subscription._closed.set() + + subscription._start_if_needed() + + assert subscription._listener is None + + def test_iterator_skips_none_results_and_keeps_polling(self): + subscription = _StreamsSubscription(FakeStreamsRedis(), "stream:iterator-none") + items = iter([None, b"event"]) + + subscription._start_if_needed = lambda: None # type: ignore[method-assign] + + def fake_receive(timeout: float | None = 0.1) -> bytes | None: + value = next(items) + if value is not None: + subscription._closed.set() + return value + + subscription.receive = fake_receive # type: ignore[method-assign] + + assert next(iter(subscription)) == b"event" + + def test_close_logs_warning_when_listener_does_not_stop_in_time( + self, + caplog: pytest.LogCaptureFixture, + ): + blocking_redis = BlockingRedis() + subscription = _StreamsSubscription(blocking_redis, "stream:slow-close") + + subscription._start_if_needed() + listener = subscription._listener + assert listener is not None + + original_join = listener.join + original_is_alive = listener.is_alive + + def delayed_join(timeout: float | None = None) -> None: + original_join(0.01) + + listener.join = delayed_join # type: ignore[method-assign] + listener.is_alive = lambda: True # type: ignore[method-assign] + + try: + subscription.close() + assert "did not stop within timeout" in caplog.text + finally: + listener.join = original_join # type: ignore[method-assign] + listener.is_alive = original_is_alive # type: ignore[method-assign] + blocking_redis.release() + original_join(timeout=1) From c6c271539572ae60e8f0dbbdadd5a638c36a84e0 Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Wed, 25 Mar 2026 11:14:12 +0800 Subject: [PATCH 2/8] fix(workflow): clear loop/iteration metadata when pasting 
node outside container (#29983) Co-authored-by: hjlarry --- .../workflow/hooks/use-nodes-interactions.ts | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index cd35d2310f..8de86edecb 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1822,6 +1822,8 @@ export const useNodesInteractions = () => { else { // single node paste const selectedNode = nodes.find(node => node.selected) + let pastedToNestedBlock = false + if (selectedNode) { const commonNestedDisallowPasteNodes = [ // end node only can be placed outermost layer @@ -1849,10 +1851,24 @@ export const useNodesInteractions = () => { } // set position base on parent node newNode.position = getNestedNodePosition(newNode, selectedNode) + // update parent children array like native add parentChildrenToAppend.push({ parentId: selectedNode.id, childId: newNode.id, childType: newNode.data.type }) + + pastedToNestedBlock = true } } + + // Clear loop/iteration metadata when pasting outside nested blocks (fixes #29835) + // This ensures nodes copied from inside Loop/Iteration are properly independent + // when pasted outside + if (!pastedToNestedBlock) { + newNode.data.isInLoop = false + newNode.data.loop_id = undefined + newNode.data.isInIteration = false + newNode.data.iteration_id = undefined + newNode.parentId = undefined + } } idMapping[nodeToPaste.id] = newNode.id From cb2888520532348e43a60c32d327bf4af324e417 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Wed, 25 Mar 2026 11:35:20 +0800 Subject: [PATCH 3/8] fix: update docs path (#34052) --- web/context/i18n.spec.ts | 4 +- web/hooks/use-api-access-url.ts | 2 +- web/types/doc-paths.ts | 290 ++++++++++++++++++-------------- 3 files changed, 168 insertions(+), 128 deletions(-) diff --git a/web/context/i18n.spec.ts 
b/web/context/i18n.spec.ts index 616f3bfced..9ebbda825e 100644 --- a/web/context/i18n.spec.ts +++ b/web/context/i18n.spec.ts @@ -184,8 +184,8 @@ describe('useDocLink', () => { vi.mocked(getDocLanguage).mockReturnValue('ja') const { result } = renderHook(() => useDocLink()) - const url = result.current('/api-reference/application/get-application-basic-information') - expect(url).toBe(`${defaultDocBaseUrl}/api-reference/アプリケーション情報/アプリケーションの基本情報を取得`) + const url = result.current('/api-reference/applications/get-app-info') + expect(url).toBe(`${defaultDocBaseUrl}/api-reference/アプリケーション設定/アプリケーションの基本情報を取得`) }) it('should not translate API reference path for English locale', () => { diff --git a/web/hooks/use-api-access-url.ts b/web/hooks/use-api-access-url.ts index 98576e66db..7f63b7754e 100644 --- a/web/hooks/use-api-access-url.ts +++ b/web/hooks/use-api-access-url.ts @@ -3,5 +3,5 @@ import { useDocLink } from '@/context/i18n' export const useDatasetApiAccessUrl = () => { const docLink = useDocLink() - return docLink('/api-reference/datasets/get-knowledge-base-list') + return docLink('/api-reference/knowledge-bases/list-knowledge-bases') } diff --git a/web/types/doc-paths.ts b/web/types/doc-paths.ts index 8f95249354..9cbad79a2e 100644 --- a/web/types/doc-paths.ts +++ b/web/types/doc-paths.ts @@ -2,7 +2,7 @@ // DON NOT EDIT IT MANUALLY // // Generated from: https://raw.githubusercontent.com/langgenius/dify-docs/refs/heads/main/docs.json -// Generated at: 2026-01-30T09:14:29.304Z +// Generated at: 2026-03-25T03:18:49.626Z // Language prefixes export type DocLanguage = 'en' | 'zh' | 'ja' @@ -61,6 +61,7 @@ export type UseDifyPath = | '/use-dify/nodes/code' | '/use-dify/nodes/doc-extractor' | '/use-dify/nodes/http-request' + | '/use-dify/nodes/human-input' | '/use-dify/nodes/ifelse' | '/use-dify/nodes/iteration' | '/use-dify/nodes/knowledge-retrieval' @@ -82,6 +83,7 @@ export type UseDifyPath = | '/use-dify/publish/README' | '/use-dify/publish/developing-with-apis' | 
'/use-dify/publish/publish-mcp' + | '/use-dify/publish/publish-to-marketplace' | '/use-dify/publish/webapp/chatflow-webapp' | '/use-dify/publish/webapp/embedding-in-websites' | '/use-dify/publish/webapp/web-app-access' @@ -92,6 +94,16 @@ export type UseDifyPath = | '/use-dify/tutorials/customer-service-bot' | '/use-dify/tutorials/simple-chatbot' | '/use-dify/tutorials/twitter-chatflow' + | '/use-dify/tutorials/workflow-101/lesson-01' + | '/use-dify/tutorials/workflow-101/lesson-02' + | '/use-dify/tutorials/workflow-101/lesson-03' + | '/use-dify/tutorials/workflow-101/lesson-04' + | '/use-dify/tutorials/workflow-101/lesson-05' + | '/use-dify/tutorials/workflow-101/lesson-06' + | '/use-dify/tutorials/workflow-101/lesson-07' + | '/use-dify/tutorials/workflow-101/lesson-08' + | '/use-dify/tutorials/workflow-101/lesson-09' + | '/use-dify/tutorials/workflow-101/lesson-10' | '/use-dify/workspace/api-extension/api-extension' | '/use-dify/workspace/api-extension/cloudflare-worker' | '/use-dify/workspace/api-extension/external-data-tool-api-extension' @@ -167,72 +179,86 @@ export type DevelopPluginPath = // API Reference paths (English, use apiReferencePathTranslations for other languages) export type ApiReferencePath = + | '/api-reference/annotations/configure-annotation-reply' | '/api-reference/annotations/create-annotation' | '/api-reference/annotations/delete-annotation' - | '/api-reference/annotations/get-annotation-list' - | '/api-reference/annotations/initial-annotation-reply-settings' - | '/api-reference/annotations/query-initial-annotation-reply-settings-task-status' + | '/api-reference/annotations/get-annotation-reply-job-status' + | '/api-reference/annotations/list-annotations' | '/api-reference/annotations/update-annotation' - | '/api-reference/application/get-application-basic-information' - | '/api-reference/application/get-application-meta-information' - | '/api-reference/application/get-application-parameters-information' - | 
'/api-reference/application/get-application-webapp-settings' - | '/api-reference/chat/next-suggested-questions' - | '/api-reference/chat/send-chat-message' - | '/api-reference/chat/stop-chat-message-generation' - | '/api-reference/chatflow/next-suggested-questions' - | '/api-reference/chatflow/send-chat-message' - | '/api-reference/chatflow/stop-advanced-chat-message-generation' - | '/api-reference/chunks/add-chunks-to-a-document' + | '/api-reference/applications/get-app-info' + | '/api-reference/applications/get-app-meta' + | '/api-reference/applications/get-app-parameters' + | '/api-reference/applications/get-app-webapp-settings' + | '/api-reference/chats/get-next-suggested-questions' + | '/api-reference/chats/send-chat-message' + | '/api-reference/chats/stop-chat-message-generation' | '/api-reference/chunks/create-child-chunk' - | '/api-reference/chunks/delete-a-chunk-in-a-document' + | '/api-reference/chunks/create-chunks' | '/api-reference/chunks/delete-child-chunk' - | '/api-reference/chunks/get-a-chunk-details-in-a-document' - | '/api-reference/chunks/get-child-chunks' - | '/api-reference/chunks/get-chunks-from-a-document' - | '/api-reference/chunks/update-a-chunk-in-a-document' + | '/api-reference/chunks/delete-chunk' + | '/api-reference/chunks/get-chunk' + | '/api-reference/chunks/list-child-chunks' + | '/api-reference/chunks/list-chunks' | '/api-reference/chunks/update-child-chunk' - | '/api-reference/completion/create-completion-message' - | '/api-reference/completion/stop-generate' - | '/api-reference/conversations/conversation-rename' + | '/api-reference/chunks/update-chunk' + | '/api-reference/completions/send-completion-message' + | '/api-reference/completions/stop-completion-message-generation' | '/api-reference/conversations/delete-conversation' - | '/api-reference/conversations/get-conversation-history-messages' - | '/api-reference/conversations/get-conversation-variables' - | '/api-reference/conversations/get-conversations' - | 
'/api-reference/datasets/create-an-empty-knowledge-base' - | '/api-reference/datasets/delete-a-knowledge-base' - | '/api-reference/datasets/get-knowledge-base-details' - | '/api-reference/datasets/get-knowledge-base-list' - | '/api-reference/datasets/retrieve-chunks-from-a-knowledge-base-/-test-retrieval' - | '/api-reference/datasets/update-knowledge-base' - | '/api-reference/documents/create-a-document-from-a-file' - | '/api-reference/documents/create-a-document-from-text' - | '/api-reference/documents/delete-a-document' - | '/api-reference/documents/get-document-detail' - | '/api-reference/documents/get-document-embedding-status-(progress)' - | '/api-reference/documents/get-the-document-list-of-a-knowledge-base' - | '/api-reference/documents/update-a-document-with-a-file' - | '/api-reference/documents/update-a-document-with-text' - | '/api-reference/documents/update-document-status' - | '/api-reference/feedback/get-feedbacks-of-application' - | '/api-reference/feedback/message-feedback' - | '/api-reference/files/file-preview' - | '/api-reference/files/file-upload' - | '/api-reference/files/file-upload-for-workflow' - | '/api-reference/metadata-&-tags/bind-dataset-to-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/create-new-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/delete-knowledge-base-type-tag' - | '/api-reference/metadata-&-tags/get-knowledge-base-type-tags' - | '/api-reference/metadata-&-tags/modify-knowledge-base-type-tag-name' - | '/api-reference/metadata-&-tags/query-tags-bound-to-a-dataset' - | '/api-reference/metadata-&-tags/unbind-dataset-and-knowledge-base-type-tag' - | '/api-reference/models/get-available-embedding-models' - | '/api-reference/tts/speech-to-text' - | '/api-reference/tts/text-to-audio' - | '/api-reference/workflow-execution/execute-workflow' - | '/api-reference/workflow-execution/get-workflow-logs' - | '/api-reference/workflow-execution/get-workflow-run-detail' - | 
'/api-reference/workflow-execution/stop-workflow-task-generation' + | '/api-reference/conversations/list-conversation-messages' + | '/api-reference/conversations/list-conversation-variables' + | '/api-reference/conversations/list-conversations' + | '/api-reference/conversations/rename-conversation' + | '/api-reference/conversations/update-conversation-variable' + | '/api-reference/documents/create-document-by-file' + | '/api-reference/documents/create-document-by-text' + | '/api-reference/documents/delete-document' + | '/api-reference/documents/download-document' + | '/api-reference/documents/download-documents-as-zip' + | '/api-reference/documents/get-document' + | '/api-reference/documents/get-document-indexing-status' + | '/api-reference/documents/list-documents' + | '/api-reference/documents/update-document-by-file' + | '/api-reference/documents/update-document-by-text' + | '/api-reference/documents/update-document-status-in-batch' + | '/api-reference/end-users/get-end-user-info' + | '/api-reference/feedback/list-app-feedbacks' + | '/api-reference/feedback/submit-message-feedback' + | '/api-reference/files/download-file' + | '/api-reference/files/upload-file' + | '/api-reference/knowledge-bases/create-an-empty-knowledge-base' + | '/api-reference/knowledge-bases/delete-knowledge-base' + | '/api-reference/knowledge-bases/get-knowledge-base' + | '/api-reference/knowledge-bases/list-knowledge-bases' + | '/api-reference/knowledge-bases/retrieve-chunks-from-a-knowledge-base-/-test-retrieval' + | '/api-reference/knowledge-bases/update-knowledge-base' + | '/api-reference/knowledge-pipeline/list-datasource-plugins' + | '/api-reference/knowledge-pipeline/run-datasource-node' + | '/api-reference/knowledge-pipeline/run-pipeline' + | '/api-reference/knowledge-pipeline/upload-pipeline-file' + | '/api-reference/metadata/create-metadata-field' + | '/api-reference/metadata/delete-metadata-field' + | '/api-reference/metadata/get-built-in-metadata-fields' + | 
'/api-reference/metadata/list-metadata-fields' + | '/api-reference/metadata/update-built-in-metadata-field' + | '/api-reference/metadata/update-document-metadata-in-batch' + | '/api-reference/metadata/update-metadata-field' + | '/api-reference/models/get-available-models' + | '/api-reference/tags/create-knowledge-tag' + | '/api-reference/tags/create-tag-binding' + | '/api-reference/tags/delete-knowledge-tag' + | '/api-reference/tags/delete-tag-binding' + | '/api-reference/tags/get-knowledge-base-tags' + | '/api-reference/tags/list-knowledge-tags' + | '/api-reference/tags/update-knowledge-tag' + | '/api-reference/tts/convert-audio-to-text' + | '/api-reference/tts/convert-text-to-audio' + | '/api-reference/workflow-runs/get-workflow-run-detail' + | '/api-reference/workflow-runs/list-workflow-logs' + | '/api-reference/workflows/get-workflow-run-detail' + | '/api-reference/workflows/list-workflow-logs' + | '/api-reference/workflows/run-workflow' + | '/api-reference/workflows/run-workflow-by-id' + | '/api-reference/workflows/stop-workflow-task' // Base path without language prefix export type DocPathWithoutLangBase = @@ -251,70 +277,84 @@ export type DifyDocPath = `${DocLanguage}/${DocPathWithoutLang}` // API Reference path translations (English -> other languages) export const apiReferencePathTranslations: Record = { - '/api-reference/annotations/create-annotation': { zh: '/api-reference/标注管理/创建标注' }, - '/api-reference/annotations/delete-annotation': { zh: '/api-reference/标注管理/删除标注' }, - '/api-reference/annotations/get-annotation-list': { zh: '/api-reference/标注管理/获取标注列表' }, - '/api-reference/annotations/initial-annotation-reply-settings': { zh: '/api-reference/标注管理/标注回复初始设置' }, - '/api-reference/annotations/query-initial-annotation-reply-settings-task-status': { zh: '/api-reference/标注管理/查询标注回复初始设置任务状态' }, - '/api-reference/annotations/update-annotation': { zh: '/api-reference/标注管理/更新标注' }, - '/api-reference/application/get-application-basic-information': { zh: 
'/api-reference/应用设置/获取应用基本信息', ja: '/api-reference/アプリケーション情報/アプリケーションの基本情報を取得' }, - '/api-reference/application/get-application-meta-information': { zh: '/api-reference/应用配置/获取应用meta信息', ja: '/api-reference/アプリケーション設定/アプリケーションのメタ情報を取得' }, - '/api-reference/application/get-application-parameters-information': { zh: '/api-reference/应用设置/获取应用参数', ja: '/api-reference/アプリケーション情報/アプリケーションのパラメータ情報を取得' }, - '/api-reference/application/get-application-webapp-settings': { zh: '/api-reference/应用设置/获取应用-webapp-设置', ja: '/api-reference/アプリケーション情報/アプリのwebapp設定を取得' }, - '/api-reference/chat/next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問' }, - '/api-reference/chat/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, - '/api-reference/chat/stop-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成停止' }, - '/api-reference/chatflow/next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問' }, - '/api-reference/chatflow/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, - '/api-reference/chatflow/stop-advanced-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成を停止' }, - '/api-reference/chunks/add-chunks-to-a-document': { zh: '/api-reference/文档块/向文档添加块', ja: '/api-reference/チャンク/ドキュメントにチャンクを追加' }, - '/api-reference/chunks/create-child-chunk': { zh: '/api-reference/文档块/创建子块', ja: '/api-reference/チャンク/子チャンクを作成' }, - '/api-reference/chunks/delete-a-chunk-in-a-document': { zh: '/api-reference/文档块/删除文档中的块', ja: '/api-reference/チャンク/ドキュメント内のチャンクを削除' }, - '/api-reference/chunks/delete-child-chunk': { zh: '/api-reference/文档块/删除子块', ja: '/api-reference/チャンク/子チャンクを削除' }, - '/api-reference/chunks/get-a-chunk-details-in-a-document': { zh: '/api-reference/文档块/获取文档中的块详情', ja: '/api-reference/チャンク/ドキュメント内のチャンク詳細を取得' 
}, - '/api-reference/chunks/get-child-chunks': { zh: '/api-reference/文档块/获取子块', ja: '/api-reference/チャンク/子チャンクを取得' }, - '/api-reference/chunks/get-chunks-from-a-document': { zh: '/api-reference/文档块/从文档获取块', ja: '/api-reference/チャンク/ドキュメントからチャンクを取得' }, - '/api-reference/chunks/update-a-chunk-in-a-document': { zh: '/api-reference/文档块/更新文档中的块', ja: '/api-reference/チャンク/ドキュメント内のチャンクを更新' }, - '/api-reference/chunks/update-child-chunk': { zh: '/api-reference/文档块/更新子块', ja: '/api-reference/チャンク/子チャンクを更新' }, - '/api-reference/completion/create-completion-message': { zh: '/api-reference/文本生成/发送消息', ja: '/api-reference/完了メッセージ/完了メッセージの作成' }, - '/api-reference/completion/stop-generate': { zh: '/api-reference/文本生成/停止响应', ja: '/api-reference/完了メッセージ/生成の停止' }, - '/api-reference/conversations/conversation-rename': { zh: '/api-reference/会话管理/会话重命名', ja: '/api-reference/会話管理/会話の名前を変更' }, + '/api-reference/annotations/configure-annotation-reply': { zh: '/api-reference/标注管理/配置标注回复', ja: '/api-reference/アノテーション管理/アノテーション返信を設定' }, + '/api-reference/annotations/create-annotation': { zh: '/api-reference/标注管理/创建标注', ja: '/api-reference/アノテーション管理/アノテーションを作成' }, + '/api-reference/annotations/delete-annotation': { zh: '/api-reference/标注管理/删除标注', ja: '/api-reference/アノテーション管理/アノテーションを削除' }, + '/api-reference/annotations/get-annotation-reply-job-status': { zh: '/api-reference/标注管理/查询标注回复配置任务状态', ja: '/api-reference/アノテーション管理/アノテーション返信の初期設定タスクステータスを取得' }, + '/api-reference/annotations/list-annotations': { zh: '/api-reference/标注管理/获取标注列表', ja: '/api-reference/アノテーション管理/アノテーションリストを取得' }, + '/api-reference/annotations/update-annotation': { zh: '/api-reference/标注管理/更新标注', ja: '/api-reference/アノテーション管理/アノテーションを更新' }, + '/api-reference/applications/get-app-info': { zh: '/api-reference/应用配置/获取应用基本信息', ja: '/api-reference/アプリケーション設定/アプリケーションの基本情報を取得' }, + '/api-reference/applications/get-app-meta': { zh: '/api-reference/应用配置/获取应用元数据', ja: '/api-reference/アプリケーション設定/アプリケーションのメタ情報を取得' }, + 
'/api-reference/applications/get-app-parameters': { zh: '/api-reference/应用配置/获取应用参数', ja: '/api-reference/アプリケーション設定/アプリケーションのパラメータ情報を取得' }, + '/api-reference/applications/get-app-webapp-settings': { zh: '/api-reference/应用配置/获取应用-webapp-设置', ja: '/api-reference/アプリケーション設定/アプリの-webapp-設定を取得' }, + '/api-reference/chats/get-next-suggested-questions': { zh: '/api-reference/对话消息/获取下一轮建议问题列表', ja: '/api-reference/チャットメッセージ/次の推奨質問を取得' }, + '/api-reference/chats/send-chat-message': { zh: '/api-reference/对话消息/发送对话消息', ja: '/api-reference/チャットメッセージ/チャットメッセージを送信' }, + '/api-reference/chats/stop-chat-message-generation': { zh: '/api-reference/对话消息/停止响应', ja: '/api-reference/チャットメッセージ/生成を停止' }, + '/api-reference/chunks/create-child-chunk': { zh: '/api-reference/分段/创建子分段', ja: '/api-reference/チャンク/子チャンクを作成' }, + '/api-reference/chunks/create-chunks': { zh: '/api-reference/分段/向文档添加分段', ja: '/api-reference/チャンク/ドキュメントにチャンクを追加' }, + '/api-reference/chunks/delete-child-chunk': { zh: '/api-reference/分段/删除子分段', ja: '/api-reference/チャンク/子チャンクを削除' }, + '/api-reference/chunks/delete-chunk': { zh: '/api-reference/分段/删除文档中的分段', ja: '/api-reference/チャンク/ドキュメント内のチャンクを削除' }, + '/api-reference/chunks/get-chunk': { zh: '/api-reference/分段/获取文档中的分段详情', ja: '/api-reference/チャンク/ドキュメント内のチャンク詳細を取得' }, + '/api-reference/chunks/list-child-chunks': { zh: '/api-reference/分段/获取子分段', ja: '/api-reference/チャンク/子チャンク一覧を取得' }, + '/api-reference/chunks/list-chunks': { zh: '/api-reference/分段/从文档获取分段', ja: '/api-reference/チャンク/チャンク一覧を取得' }, + '/api-reference/chunks/update-child-chunk': { zh: '/api-reference/分段/更新子分段', ja: '/api-reference/チャンク/子チャンクを更新' }, + '/api-reference/chunks/update-chunk': { zh: '/api-reference/分段/更新文档中的分段', ja: '/api-reference/チャンク/ドキュメント内のチャンクを更新' }, + '/api-reference/completions/send-completion-message': { zh: '/api-reference/文本生成/发送消息', ja: '/api-reference/完了メッセージ/完了メッセージを送信' }, + '/api-reference/completions/stop-completion-message-generation': { zh: '/api-reference/文本生成/停止响应', ja: 
'/api-reference/完了メッセージ/生成を停止' }, '/api-reference/conversations/delete-conversation': { zh: '/api-reference/会话管理/删除会话', ja: '/api-reference/会話管理/会話を削除' }, - '/api-reference/conversations/get-conversation-history-messages': { zh: '/api-reference/会话管理/获取会话历史消息', ja: '/api-reference/会話管理/会話履歴メッセージを取得' }, - '/api-reference/conversations/get-conversation-variables': { zh: '/api-reference/会话管理/获取对话变量', ja: '/api-reference/会話管理/会話変数の取得' }, - '/api-reference/conversations/get-conversations': { zh: '/api-reference/会话管理/获取会话列表', ja: '/api-reference/会話管理/会話を取得' }, - '/api-reference/datasets/create-an-empty-knowledge-base': { zh: '/api-reference/数据集/创建空知识库', ja: '/api-reference/データセット/空のナレッジベースを作成' }, - '/api-reference/datasets/delete-a-knowledge-base': { zh: '/api-reference/数据集/删除知识库', ja: '/api-reference/データセット/ナレッジベースを削除' }, - '/api-reference/datasets/get-knowledge-base-details': { zh: '/api-reference/数据集/获取知识库详情', ja: '/api-reference/データセット/ナレッジベース詳細を取得' }, - '/api-reference/datasets/get-knowledge-base-list': { zh: '/api-reference/数据集/获取知识库列表', ja: '/api-reference/データセット/ナレッジベースリストを取得' }, - '/api-reference/datasets/retrieve-chunks-from-a-knowledge-base-/-test-retrieval': { zh: '/api-reference/数据集/从知识库检索块-/-测试检索', ja: '/api-reference/データセット/ナレッジベースからチャンクを取得-/-テスト検索' }, - '/api-reference/datasets/update-knowledge-base': { zh: '/api-reference/数据集/更新知识库', ja: '/api-reference/データセット/ナレッジベースを更新' }, - '/api-reference/documents/create-a-document-from-a-file': { zh: '/api-reference/文档/从文件创建文档', ja: '/api-reference/ドキュメント/ファイルからドキュメントを作成' }, - '/api-reference/documents/create-a-document-from-text': { zh: '/api-reference/文档/从文本创建文档', ja: '/api-reference/ドキュメント/テキストからドキュメントを作成' }, - '/api-reference/documents/delete-a-document': { zh: '/api-reference/文档/删除文档', ja: '/api-reference/ドキュメント/ドキュメントを削除' }, - '/api-reference/documents/get-document-detail': { zh: '/api-reference/文档/获取文档详情', ja: '/api-reference/ドキュメント/ドキュメント詳細を取得' }, - 
'/api-reference/documents/get-document-embedding-status-(progress)': { zh: '/api-reference/文档/获取文档嵌入状态(进度)', ja: '/api-reference/ドキュメント/ドキュメント埋め込みステータス(進捗)を取得' }, - '/api-reference/documents/get-the-document-list-of-a-knowledge-base': { zh: '/api-reference/文档/获取知识库的文档列表', ja: '/api-reference/ドキュメント/ナレッジベースのドキュメントリストを取得' }, - '/api-reference/documents/update-a-document-with-a-file': { zh: '/api-reference/文档/用文件更新文档', ja: '/api-reference/ドキュメント/ファイルでドキュメントを更新' }, - '/api-reference/documents/update-a-document-with-text': { zh: '/api-reference/文档/用文本更新文档', ja: '/api-reference/ドキュメント/テキストでドキュメントを更新' }, - '/api-reference/documents/update-document-status': { zh: '/api-reference/文档/更新文档状态', ja: '/api-reference/ドキュメント/ドキュメントステータスを更新' }, - '/api-reference/feedback/get-feedbacks-of-application': { zh: '/api-reference/反馈/获取应用反馈列表', ja: '/api-reference/メッセージフィードバック/アプリのメッセージの「いいね」とフィードバックを取得' }, - '/api-reference/feedback/message-feedback': { zh: '/api-reference/反馈/消息反馈(点赞)', ja: '/api-reference/メッセージフィードバック/メッセージフィードバック' }, - '/api-reference/files/file-preview': { zh: '/api-reference/文件操作/文件预览', ja: '/api-reference/ファイル操作/ファイルプレビュー' }, - '/api-reference/files/file-upload': { zh: '/api-reference/文件管理/上传文件', ja: '/api-reference/ファイル操作/ファイルアップロード' }, - '/api-reference/files/file-upload-for-workflow': { zh: '/api-reference/文件操作-(workflow)/上传文件-(workflow)', ja: '/api-reference/ファイル操作-(ワークフロー)/ファイルアップロード-(ワークフロー用)' }, - '/api-reference/metadata-&-tags/bind-dataset-to-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/将数据集绑定到知识库类型标签', ja: '/api-reference/メタデータ・タグ/データセットをナレッジベースタイプタグにバインド' }, - '/api-reference/metadata-&-tags/create-new-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/创建新的知识库类型标签', ja: '/api-reference/メタデータ・タグ/新しいナレッジベースタイプタグを作成' }, - '/api-reference/metadata-&-tags/delete-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/删除知识库类型标签', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグを削除' }, - '/api-reference/metadata-&-tags/get-knowledge-base-type-tags': { zh: 
'/api-reference/元数据和标签/获取知识库类型标签', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグを取得' }, - '/api-reference/metadata-&-tags/modify-knowledge-base-type-tag-name': { zh: '/api-reference/元数据和标签/修改知识库类型标签名称', ja: '/api-reference/メタデータ・タグ/ナレッジベースタイプタグ名を変更' }, - '/api-reference/metadata-&-tags/query-tags-bound-to-a-dataset': { zh: '/api-reference/元数据和标签/查询绑定到数据集的标签', ja: '/api-reference/メタデータ・タグ/データセットにバインドされたタグをクエリ' }, - '/api-reference/metadata-&-tags/unbind-dataset-and-knowledge-base-type-tag': { zh: '/api-reference/元数据和标签/解绑数据集和知识库类型标签', ja: '/api-reference/メタデータ・タグ/データセットとナレッジベースタイプタグのバインドを解除' }, - '/api-reference/models/get-available-embedding-models': { zh: '/api-reference/模型/获取可用的嵌入模型', ja: '/api-reference/モデル/利用可能な埋め込みモデルを取得' }, - '/api-reference/tts/speech-to-text': { zh: '/api-reference/语音与文字转换/语音转文字', ja: '/api-reference/音声・テキスト変換/音声からテキストへ' }, - '/api-reference/tts/text-to-audio': { zh: '/api-reference/语音服务/文字转语音', ja: '/api-reference/音声変換/テキストから音声' }, - '/api-reference/workflow-execution/execute-workflow': { zh: '/api-reference/工作流执行/执行-workflow', ja: '/api-reference/ワークフロー実行/ワークフローを実行' }, - '/api-reference/workflow-execution/get-workflow-logs': { zh: '/api-reference/工作流执行/获取-workflow-日志', ja: '/api-reference/ワークフロー実行/ワークフローログを取得' }, - '/api-reference/workflow-execution/get-workflow-run-detail': { zh: '/api-reference/工作流执行/获取workflow执行情况', ja: '/api-reference/ワークフロー実行/ワークフロー実行詳細を取得' }, - '/api-reference/workflow-execution/stop-workflow-task-generation': { zh: '/api-reference/工作流执行/停止响应-(workflow-task)', ja: '/api-reference/ワークフロー実行/生成を停止-(ワークフロータスク)' }, + '/api-reference/conversations/list-conversation-messages': { zh: '/api-reference/会话管理/获取会话历史消息', ja: '/api-reference/会話管理/会話履歴メッセージ一覧を取得' }, + '/api-reference/conversations/list-conversation-variables': { zh: '/api-reference/会话管理/获取对话变量', ja: '/api-reference/会話管理/会話変数の取得' }, + '/api-reference/conversations/list-conversations': { zh: '/api-reference/会话管理/获取会话列表', ja: '/api-reference/会話管理/会話一覧を取得' }, + 
'/api-reference/conversations/rename-conversation': { zh: '/api-reference/会话管理/重命名会话', ja: '/api-reference/会話管理/会話の名前を変更' }, + '/api-reference/conversations/update-conversation-variable': { zh: '/api-reference/会话管理/更新对话变量', ja: '/api-reference/会話管理/会話変数を更新' }, + '/api-reference/documents/create-document-by-file': { zh: '/api-reference/文档/从文件创建文档', ja: '/api-reference/ドキュメント/ファイルからドキュメントを作成' }, + '/api-reference/documents/create-document-by-text': { zh: '/api-reference/文档/从文本创建文档', ja: '/api-reference/ドキュメント/テキストからドキュメントを作成' }, + '/api-reference/documents/delete-document': { zh: '/api-reference/文档/删除文档', ja: '/api-reference/ドキュメント/ドキュメントを削除' }, + '/api-reference/documents/download-document': { zh: '/api-reference/文档/下载文档', ja: '/api-reference/ドキュメント/ドキュメントをダウンロード' }, + '/api-reference/documents/download-documents-as-zip': { zh: '/api-reference/文档/批量下载文档(zip)', ja: '/api-reference/ドキュメント/ドキュメントを一括ダウンロード(zip)' }, + '/api-reference/documents/get-document': { zh: '/api-reference/文档/获取文档详情', ja: '/api-reference/ドキュメント/ドキュメント詳細を取得' }, + '/api-reference/documents/get-document-indexing-status': { zh: '/api-reference/文档/获取文档嵌入状态(进度)', ja: '/api-reference/ドキュメント/ドキュメント埋め込みステータス(進捗)を取得' }, + '/api-reference/documents/list-documents': { zh: '/api-reference/文档/获取知识库的文档列表', ja: '/api-reference/ドキュメント/ナレッジベースのドキュメントリストを取得' }, + '/api-reference/documents/update-document-by-file': { zh: '/api-reference/文档/用文件更新文档', ja: '/api-reference/ドキュメント/ファイルでドキュメントを更新' }, + '/api-reference/documents/update-document-by-text': { zh: '/api-reference/文档/用文本更新文档', ja: '/api-reference/ドキュメント/テキストでドキュメントを更新' }, + '/api-reference/documents/update-document-status-in-batch': { zh: '/api-reference/文档/批量更新文档状态', ja: '/api-reference/ドキュメント/ドキュメントステータスを一括更新' }, + '/api-reference/end-users/get-end-user-info': { zh: '/api-reference/终端用户/获取终端用户', ja: '/api-reference/エンドユーザー/エンドユーザー取得' }, + '/api-reference/feedback/list-app-feedbacks': { zh: '/api-reference/消息反馈/获取应用的消息反馈', ja: 
'/api-reference/メッセージフィードバック/アプリのフィードバック一覧を取得' }, + '/api-reference/feedback/submit-message-feedback': { zh: '/api-reference/消息反馈/提交消息反馈', ja: '/api-reference/メッセージフィードバック/メッセージフィードバックを送信' }, + '/api-reference/files/download-file': { zh: '/api-reference/文件操作/下载文件', ja: '/api-reference/ファイル操作/ファイルをダウンロード' }, + '/api-reference/files/upload-file': { zh: '/api-reference/文件操作/上传文件', ja: '/api-reference/ファイル操作/ファイルをアップロード' }, + '/api-reference/knowledge-bases/create-an-empty-knowledge-base': { zh: '/api-reference/知识库/创建空知识库', ja: '/api-reference/データセット/空のナレッジベースを作成' }, + '/api-reference/knowledge-bases/delete-knowledge-base': { zh: '/api-reference/知识库/删除知识库', ja: '/api-reference/データセット/ナレッジベースを削除' }, + '/api-reference/knowledge-bases/get-knowledge-base': { zh: '/api-reference/知识库/获取知识库详情', ja: '/api-reference/データセット/ナレッジベース詳細を取得' }, + '/api-reference/knowledge-bases/list-knowledge-bases': { zh: '/api-reference/知识库/获取知识库列表', ja: '/api-reference/データセット/ナレッジベースリストを取得' }, + '/api-reference/knowledge-bases/retrieve-chunks-from-a-knowledge-base-/-test-retrieval': { zh: '/api-reference/知识库/从知识库检索分段-/-测试检索', ja: '/api-reference/データセット/ナレッジベースからチャンクを取得-/-テスト検索' }, + '/api-reference/knowledge-bases/update-knowledge-base': { zh: '/api-reference/知识库/更新知识库', ja: '/api-reference/データセット/ナレッジベースを更新' }, + '/api-reference/knowledge-pipeline/list-datasource-plugins': { zh: '/api-reference/知识流水线/获取数据源插件列表', ja: '/api-reference/ナレッジパイプライン/データソースプラグインリストを取得' }, + '/api-reference/knowledge-pipeline/run-datasource-node': { zh: '/api-reference/知识流水线/执行数据源节点', ja: '/api-reference/ナレッジパイプライン/データソースノードを実行' }, + '/api-reference/knowledge-pipeline/run-pipeline': { zh: '/api-reference/知识流水线/运行流水线', ja: '/api-reference/ナレッジパイプライン/パイプラインを実行' }, + '/api-reference/knowledge-pipeline/upload-pipeline-file': { zh: '/api-reference/知识流水线/上传流水线文件', ja: '/api-reference/ナレッジパイプライン/パイプラインファイルをアップロード' }, + '/api-reference/metadata/create-metadata-field': { zh: '/api-reference/元数据/创建元数据字段', ja: 
'/api-reference/メタデータ/メタデータフィールドを作成' }, + '/api-reference/metadata/delete-metadata-field': { zh: '/api-reference/元数据/删除元数据字段', ja: '/api-reference/メタデータ/メタデータフィールドを削除' }, + '/api-reference/metadata/get-built-in-metadata-fields': { zh: '/api-reference/元数据/获取内置元数据字段', ja: '/api-reference/メタデータ/組み込みメタデータフィールドを取得' }, + '/api-reference/metadata/list-metadata-fields': { zh: '/api-reference/元数据/获取元数据字段列表', ja: '/api-reference/メタデータ/メタデータフィールドリストを取得' }, + '/api-reference/metadata/update-built-in-metadata-field': { zh: '/api-reference/元数据/更新内置元数据字段', ja: '/api-reference/メタデータ/組み込みメタデータフィールドを更新' }, + '/api-reference/metadata/update-document-metadata-in-batch': { zh: '/api-reference/元数据/批量更新文档元数据', ja: '/api-reference/メタデータ/ドキュメントメタデータを一括更新' }, + '/api-reference/metadata/update-metadata-field': { zh: '/api-reference/元数据/更新元数据字段', ja: '/api-reference/メタデータ/メタデータフィールドを更新' }, + '/api-reference/models/get-available-models': { zh: '/api-reference/模型/获取可用模型', ja: '/api-reference/モデル/利用可能なモデルを取得' }, + '/api-reference/tags/create-knowledge-tag': { zh: '/api-reference/标签/创建知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを作成' }, + '/api-reference/tags/create-tag-binding': { zh: '/api-reference/标签/绑定标签到知识库', ja: '/api-reference/タグ管理/タグをデータセットにバインド' }, + '/api-reference/tags/delete-knowledge-tag': { zh: '/api-reference/标签/删除知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを削除' }, + '/api-reference/tags/delete-tag-binding': { zh: '/api-reference/标签/解除标签与知识库的绑定', ja: '/api-reference/タグ管理/タグとデータセットのバインドを解除' }, + '/api-reference/tags/get-knowledge-base-tags': { zh: '/api-reference/标签/获取知识库绑定的标签', ja: '/api-reference/タグ管理/ナレッジベースにバインドされたタグを取得' }, + '/api-reference/tags/list-knowledge-tags': { zh: '/api-reference/标签/获取知识库标签列表', ja: '/api-reference/タグ管理/ナレッジベースタグリストを取得' }, + '/api-reference/tags/update-knowledge-tag': { zh: '/api-reference/标签/修改知识库标签', ja: '/api-reference/タグ管理/ナレッジベースタグを変更' }, + '/api-reference/tts/convert-audio-to-text': { zh: '/api-reference/语音与文字转换/语音转文字', ja: 
'/api-reference/音声・テキスト変換/音声をテキストに変換' }, + '/api-reference/tts/convert-text-to-audio': { zh: '/api-reference/语音与文字转换/文字转语音', ja: '/api-reference/音声・テキスト変換/テキストを音声に変換' }, + '/api-reference/workflow-runs/get-workflow-run-detail': { zh: '/api-reference/工作流执行/获取工作流执行情况', ja: '/api-reference/ワークフロー実行/ワークフロー実行詳細を取得' }, + '/api-reference/workflow-runs/list-workflow-logs': { zh: '/api-reference/工作流执行/获取工作流日志', ja: '/api-reference/ワークフロー実行/ワークフローログ一覧を取得' }, + '/api-reference/workflows/get-workflow-run-detail': { zh: '/api-reference/工作流/获取工作流执行情况', ja: '/api-reference/ワークフロー/ワークフロー実行詳細を取得' }, + '/api-reference/workflows/list-workflow-logs': { zh: '/api-reference/工作流/获取工作流日志', ja: '/api-reference/ワークフロー/ワークフローログ一覧を取得' }, + '/api-reference/workflows/run-workflow': { zh: '/api-reference/工作流/执行工作流', ja: '/api-reference/ワークフロー/ワークフローを実行' }, + '/api-reference/workflows/run-workflow-by-id': { zh: '/api-reference/工作流/按-id-执行工作流', ja: '/api-reference/ワークフロー/id-でワークフローを実行' }, + '/api-reference/workflows/stop-workflow-task': { zh: '/api-reference/工作流/停止工作流任务', ja: '/api-reference/ワークフロー/ワークフロータスクを停止' }, } From a946015ebfbf15caae318bd7aca8787b19cd9eb5 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Wed, 25 Mar 2026 04:39:58 +0100 Subject: [PATCH 4/8] test: replace indexing_technique string literals with IndexTechnique (#34042) --- .../test_paragraph_index_processor.py | 13 ++++--- .../test_parent_child_index_processor.py | 3 +- .../processor/test_qa_index_processor.py | 7 ++-- .../core/rag/indexing/test_indexing_runner.py | 18 ++++----- .../test_knowledge_index_node.py | 3 +- .../unit_tests/models/test_dataset_models.py | 21 +++++----- .../services/dataset_service_update_delete.py | 19 +++++----- .../services/document_service_validation.py | 14 +++---- .../unit_tests/services/segment_service.py | 16 ++++---- .../test_dataset_service_lock_not_owned.py | 10 ++--- .../services/test_summary_index_service.py | 16 ++++---- 
.../services/test_vector_service.py | 38 +++++++++---------- .../unit_tests/services/vector_service.py | 36 +++++++++--------- .../tasks/test_clean_dataset_task.py | 16 ++++---- .../tasks/test_dataset_indexing_task.py | 4 +- 15 files changed, 120 insertions(+), 114 deletions(-) diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index e6cc582398..2c234edd9a 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: 
ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -267,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 5c78cae7c1..b1ed735ee7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import 
PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 99323eeec9..98c47bec8f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", 
metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -298,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..b54a74b69c 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,7 +61,7 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -521,7 +521,7 @@ 
class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -605,7 +605,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -674,7 +674,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +701,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -795,7 +795,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +949,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY 
mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..feb560bbc3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,6 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 98dd07907a..6c8a91129b 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -15,6 +15,7 @@ from datetime import UTC, datetime from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import ( AppDatasetJoin, ChildChunk, @@ -67,14 +68,14 @@ class TestDatasetModelValidation: data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", - indexing_technique="high_quality", + 
indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) # Assert assert dataset.description == "Test description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -86,21 +87,21 @@ class TestDatasetModelValidation: name="High Quality Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) # Assert - assert dataset_high_quality.indexing_technique == "high_quality" - assert dataset_economy.indexing_technique == "economy" - assert "high_quality" in Dataset.INDEXING_TECHNIQUE_LIST - assert "economy" in Dataset.INDEXING_TECHNIQUE_LIST + assert dataset_high_quality.indexing_technique == IndexTechniqueType.HIGH_QUALITY + assert dataset_economy.indexing_technique == IndexTechniqueType.ECONOMY + assert IndexTechniqueType.HIGH_QUALITY in Dataset.INDEXING_TECHNIQUE_LIST + assert IndexTechniqueType.ECONOMY in Dataset.INDEXING_TECHNIQUE_LIST def test_dataset_provider_validation(self): """Test dataset provider values.""" @@ -983,7 +984,7 @@ class TestModelIntegration: name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) dataset.id = dataset_id @@ -1019,7 +1020,7 @@ class TestModelIntegration: assert document.dataset_id == dataset_id assert segment.dataset_id == dataset_id assert segment.document_id == document_id - assert dataset.indexing_technique == 
"high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert document.word_count == 100 assert segment.status == SegmentStatus.COMPLETED diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index c805dd98e2..424ac18870 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -97,6 +97,7 @@ from unittest.mock import Mock, create_autospec, patch import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -149,7 +150,7 @@ class DatasetUpdateDeleteTestDataFactory: name: str = "Test Dataset", description: str = "Test description", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str | None = "openai", embedding_model: str | None = "text-embedding-ada-002", collection_binding_id: str | None = "binding-123", @@ -237,7 +238,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_knowledge_configuration_mock( chunk_structure: str = "tree", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", keyword_number: int = 10, @@ -630,12 +631,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", - indexing_technique="high_quality", + 
indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -671,7 +672,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: # Assert assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == "binding-123" @@ -698,12 +699,12 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset_id="dataset-123", runtime_mode="rag_pipeline", chunk_structure="tree", # Existing structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( chunk_structure="list", # Different structure - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) mock_session.merge.return_value = dataset @@ -735,11 +736,11 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( dataset_id="dataset-123", runtime_mode="rag_pipeline", - indexing_technique="high_quality", # Current technique + indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique ) knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique="economy", # Trying to change to economy + indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy ) mock_session.merge.return_value = dataset diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 1f68ff6b3d..49fdc5cc9b 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ 
b/api/tests/unit_tests/services/document_service_validation.py @@ -111,7 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService @@ -154,7 +154,7 @@ class DocumentValidationTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", **kwargs, @@ -190,7 +190,7 @@ class DocumentValidationTestDataFactory: data_source: DataSource | None = None, process_rule: ProcessRule | None = None, doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, **kwargs, ) -> Mock: """ @@ -448,7 +448,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) @@ -481,7 +481,7 @@ class TestDatasetServiceCheckDatasetModelSetting: - No errors are raised """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = DocumentValidationTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) # Act (should not raise) DatasetService.check_dataset_model_setting(dataset) @@ -503,7 
+503,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="invalid-model", ) @@ -533,7 +533,7 @@ class TestDatasetServiceCheckDatasetModelSetting: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock( - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", ) diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index 5e625fa0cd..14af7f7119 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import SegmentType @@ -111,7 +111,7 @@ class SegmentTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model: str = "text-embedding-ada-002", embedding_model_provider: str = "openai", **kwargs, @@ -163,7 +163,7 @@ class TestSegmentServiceCreateSegment: """Test successful creation of a segment.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args 
= {"content": "New segment content", "keywords": ["test", "segment"]} mock_query = MagicMock() @@ -212,7 +212,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with QA model (requires answer).""" # Arrange document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} mock_query = MagicMock() @@ -247,7 +247,7 @@ class TestSegmentServiceCreateSegment: """Test creation of segment with high quality indexing technique.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -289,7 +289,7 @@ class TestSegmentServiceCreateSegment: """Test segment creation when vector indexing fails.""" # Arrange document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = {"content": "New segment content", "keywords": ["test"]} mock_query = MagicMock() @@ -342,7 +342,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = 
SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated content", keywords=["updated"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment @@ -431,7 +431,7 @@ class TestSegmentServiceUpdateSegment: # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) - dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) mock_db_session.query.return_value.where.return_value.first.return_value = segment diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index d2287e8982..9a513c3fe6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -71,7 +71,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + dataset.indexing_technique = 
IndexTechniqueType.HIGH_QUALITY # so we skip re-initialization branch # Minimal knowledge_config stub that satisfies pre-lock code info_list = types.SimpleNamespace(data_source_type="upload_file") @@ -80,7 +80,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model=None, embedding_model_provider=None, retrieval_model=None, @@ -126,7 +126,7 @@ def test_add_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # skip embedding/token calculation branch + dataset.indexing_technique = IndexTechniqueType.ECONOMY # skip embedding/token calculation branch document = create_autospec(Document, instance=True) document.id = "doc-1" @@ -169,7 +169,7 @@ def test_multi_create_segment_ignores_lock_not_owned( dataset = create_autospec(Dataset, instance=True) dataset.id = "ds-1" dataset.tenant_id = fake_current_user.current_tenant_id - dataset.indexing_technique = "economy" # again, skip high_quality path + dataset.indexing_technique = IndexTechniqueType.ECONOMY # again, skip high_quality path document = create_autospec(Document, instance=True) document.id = "doc-1" diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index c4285c73a0..ef53df9350 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module -from core.rag.index_processor.constant.index_type import IndexStructureType +from 
core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -27,7 +27,7 @@ class _SessionContext: return None -def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: +def _dataset(*, indexing_technique: str = IndexTechniqueType.HIGH_QUALITY) -> MagicMock: dataset = MagicMock(name="dataset") dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" @@ -169,7 +169,8 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: vector_cls = MagicMock() monkeypatch.setattr(summary_module, "Vector", vector_cls) - SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), dataset) vector_cls.assert_not_called() @@ -621,7 +622,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _dataset(indexing_technique="economy") + dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" document.doc_form = IndexStructureType.PARAGRAPH_INDEX @@ -778,7 +779,7 @@ def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mo def test_enable_summaries_for_segments_skips_non_high_quality() -> None: - SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique=IndexTechniqueType.ECONOMY)) def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: 
pytest.MonkeyPatch) -> None: @@ -932,9 +933,8 @@ def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.Mon def test_update_summary_for_segment_skip_conditions() -> None: - assert ( - SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None - ) + economy_dataset = _dataset(indexing_technique=IndexTechniqueType.ECONOMY) + assert SummaryIndexService.update_summary_for_segment(_segment(), economy_dataset, "x") is None seg = _segment(has_document=True) seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index d3a98dd4bb..16d3011810 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from services.vector_service import VectorService @@ -32,7 +32,7 @@ class _ParentDocStub: def _make_dataset( *, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", @@ -192,7 +192,7 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider="openai", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -241,7 +241,7 @@ def 
test_create_segments_vector_parent_child_uses_default_embedding_model_when_p dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, embedding_model_provider=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) segment = _make_segment() @@ -329,7 +329,7 @@ def test_create_segments_vector_parent_child_missing_processing_rule_raises(monk def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _make_dataset( doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) segment = _make_segment() dataset_document = MagicMock() @@ -348,7 +348,7 @@ def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = _make_segment() vector_instance = MagicMock() @@ -364,7 +364,7 @@ def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.Monk def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -380,7 +380,7 @@ def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypat def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) segment = _make_segment() keyword_instance = MagicMock() @@ -473,7 +473,7 @@ def 
test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = MagicMock() child_chunk.content = "child" child_chunk.index_node_id = "id" @@ -489,7 +489,7 @@ def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.M def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) @@ -505,7 +505,7 @@ def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = MagicMock() new_chunk.content = "n" @@ -536,7 +536,7 @@ def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pyte def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy") + dataset = _make_dataset(indexing_technique=IndexTechniqueType.ECONOMY) vector_cls = MagicMock() monkeypatch.setattr(vector_service_module, "Vector", vector_cls) VectorService.update_child_chunk_vector([], [], [], dataset) @@ -561,7 +561,7 @@ def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + dataset = 
_make_dataset(indexing_technique=IndexTechniqueType.ECONOMY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) vector_cls = MagicMock() @@ -575,7 +575,7 @@ def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pyt def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) vector_cls = MagicMock() @@ -591,7 +591,7 @@ def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pyt def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) vector_instance = MagicMock(name="vector_instance") @@ -612,7 +612,7 @@ def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -630,7 +630,7 @@ def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = 
_make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -663,7 +663,7 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( monkeypatch: pytest.MonkeyPatch, ) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=False) segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() @@ -683,7 +683,7 @@ def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + dataset = _make_dataset(indexing_technique=IndexTechniqueType.HIGH_QUALITY, is_multimodal=True) segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) vector_instance = MagicMock() diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index e180063041..33a5607ef4 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,7 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document 
import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -153,7 +153,7 @@ class VectorServiceTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", doc_form: str = IndexStructureType.PARAGRAPH_INDEX, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", index_struct_dict: dict | None = None, @@ -494,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -535,7 +535,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -568,7 +568,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -591,7 +591,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="high_quality" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.HIGH_QUALITY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -616,7 +616,7 @@ class 
TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="parent_child_model", indexing_technique="economy" + doc_form="parent_child_model", indexing_technique=IndexTechniqueType.ECONOMY ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -669,7 +669,7 @@ class TestVectorService: store when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -695,7 +695,7 @@ class TestVectorService: index when using economy indexing with keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -731,7 +731,7 @@ class TestVectorService: index when using economy indexing without keywords. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -895,7 +895,7 @@ class TestVectorService: when using high_quality indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -923,7 +923,7 @@ class TestVectorService: using economy indexing. 
""" # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -951,7 +951,7 @@ class TestVectorService: when there are new chunks, updated chunks, and deleted chunks. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock(chunk_id="new-chunk-1") @@ -993,7 +993,7 @@ class TestVectorService: add_texts is called, not delete_by_ids. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1019,7 +1019,7 @@ class TestVectorService: delete_by_ids is called, not add_texts. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) delete_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1045,7 +1045,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) new_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1075,7 +1075,7 @@ class TestVectorService: when using high_quality indexing. 
""" # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="high_quality") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.HIGH_QUALITY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() @@ -1099,7 +1099,7 @@ class TestVectorService: using economy indexing. """ # Arrange - dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique="economy") + dataset = VectorServiceTestDataFactory.create_dataset_mock(indexing_technique=IndexTechniqueType.ECONOMY) child_chunk = VectorServiceTestDataFactory.create_child_chunk_mock() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index c0a4d2f113..936a10d6c5 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -184,7 +184,7 @@ class TestErrorHandling: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -229,7 +229,7 @@ class TestPipelineAndWorkflowDeletion: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -265,7 +265,7 @@ class TestPipelineAndWorkflowDeletion: 
clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -321,7 +321,7 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -366,7 +366,7 @@ class TestSegmentAttachmentCleanup: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -408,7 +408,7 @@ class TestEdgeCases: clean_dataset_task( dataset_id=dataset_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, doc_form=IndexStructureType.PARAGRAPH_INDEX, @@ -445,7 +445,7 @@ class TestIndexProcessorParameters: - Dataset object with correct attributes is passed """ # Arrange - indexing_technique = "high_quality" + indexing_technique = IndexTechniqueType.HIGH_QUALITY index_struct = '{"type": "paragraph"}' # Act diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 027cd3b1ec..0b189ebae2 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError -from 
core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -209,7 +209,7 @@ def mock_dataset(dataset_id, tenant_id): dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset From a3855eca8ba0bca2776ef614aad1654ad443e95c Mon Sep 17 00:00:00 2001 From: Desel72 Date: Tue, 24 Mar 2026 22:42:41 -0500 Subject: [PATCH 5/8] test: migrate webapp auth service tests to testcontainers (#34037) --- .../services/test_webapp_auth_service.py | 379 ------------------ 1 file changed, 379 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_webapp_auth_service.py diff --git a/api/tests/unit_tests/services/test_webapp_auth_service.py b/api/tests/unit_tests/services/test_webapp_auth_service.py deleted file mode 100644 index 262c1f1524..0000000000 --- a/api/tests/unit_tests/services/test_webapp_auth_service.py +++ /dev/null @@ -1,379 +0,0 @@ -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import NotFound, Unauthorized - -from models import Account, AccountStatus -from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError -from services.webapp_auth_service import WebAppAuthService, WebAppAuthType - -ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback" -TOKEN_GENERATE_PATH = 
"services.webapp_auth_service.TokenManager.generate_token" -TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data" - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.webapp_auth_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act + Assert - with pytest.raises(AccountNotFoundError): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(AccountLoginError, match="Account is banned"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -@pytest.mark.parametrize("password_value", [None, "hash"]) -def test_authenticate_should_raise_password_error_when_password_is_invalid( - password_value: str | None, - mocker: MockerFixture, -) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=False) - - # Act + Assert - with pytest.raises(AccountPasswordError, match="Invalid email or password"): - WebAppAuthService.authenticate("user@example.com", "pwd") - - -def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE, 
password="hash", password_salt="salt") - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - mocker.patch("services.webapp_auth_service.compare_password", return_value=True) - - # Act - result = WebAppAuthService.authenticate("user@example.com", "pwd") - - # Assert - assert result is account - - -def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="u@example.com") - mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token") - - # Act - result = WebAppAuthService.login(account) - - # Assert - assert result == "jwt-token" - mock_get_token.assert_called_once_with(account=account) - - -def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None: - # Arrange - mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) - - # Act - result = WebAppAuthService.get_user_through_email("missing@example.com") - - # Assert - assert result is None - - -def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.BANNED) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act + Assert - with pytest.raises(Unauthorized, match="Account is banned"): - WebAppAuthService.get_user_through_email("user@example.com") - - -def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None: - # Arrange - account = SimpleNamespace(status=AccountStatus.ACTIVE) - mocker.patch( - ACCOUNT_LOOKUP_PATH, - return_value=account, - ) - - # Act - result = WebAppAuthService.get_user_through_email("user@example.com") - - # Assert - assert result is account - - -def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Email must be provided"): - 
WebAppAuthService.send_email_code_login_email(account=None, email=None) - - -def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account( - mocker: MockerFixture, -) -> None: - # Arrange - account = _account(email="user@example.com") - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6]) - mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") - - # Assert - assert result == "token-1" - mock_generate_token.assert_called_once() - assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"} - mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456") - - -def test_send_email_code_login_email_should_send_mail_for_email_without_account( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0]) - mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2") - mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") - - # Act - result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans") - - # Assert - assert result == "token-2" - mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000") - - -def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"}) - - # Act - result = WebAppAuthService.get_email_code_login_data("token-abc") - - # Assert - assert result == {"code": "123"} - mock_get_data.assert_called_once_with("token-abc", "email_code_login") - - -def 
test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None: - # Arrange - mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token") - - # Act - WebAppAuthService.revoke_email_code_login_token("token-xyz") - - # Assert - mock_revoke.assert_called_once_with("token-xyz", "email_code_login") - - -def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(NotFound, match="Site not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_query = MagicMock() - app_query.where.return_value.first.return_value = None - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None] - - # Act + Assert - with pytest.raises(NotFound, match="App not found"): - WebAppAuthService.create_end_user("app-code", "user@example.com") - - -def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None: - # Arrange - site = SimpleNamespace(app_id="app-1") - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model] - - # Act - result = WebAppAuthService.create_end_user("app-code", "user@example.com") - - # Assert - assert result.tenant_id == "tenant-1" - assert result.app_id == "app-1" - assert result.session_id == "user@example.com" - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None: - # Arrange - account = _account(id="a1", email="user@example.com") - 
mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60) - mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1") - - # Act - token = WebAppAuthService._get_account_jwt_token(account) - - # Assert - assert token == "jwt-1" - payload = mock_issue.call_args.args[0] - assert payload["user_id"] == "a1" - assert payload["session_id"] == "user@example.com" - assert payload["token_source"] == "webapp_login_token" - assert payload["auth_type"] == "internal" - assert payload["exp"] > int(datetime.now(UTC).timestamp()) - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("private", True), - ("private_all", True), - ("public", False), - ], -) -def test_is_app_require_permission_check_should_use_access_mode_when_provided( - access_mode: str, - expected: bool, -) -> None: - # Arrange - # Act - result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode) - - # Assert - assert result is expected - - -def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or app_id must be provided"): - WebAppAuthService.is_app_require_permission_check() - - -def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None) - - # Act + Assert - with pytest.raises(ValueError, match="App ID could not be determined"): - WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - -def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - 
"services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_code="app-code") - - # Assert - assert result is True - - -def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="public"), - ) - - # Act - result = WebAppAuthService.is_app_require_permission_check(app_id="app-1") - - # Assert - assert result is False - - -@pytest.mark.parametrize( - ("access_mode", "expected"), - [ - ("public", WebAppAuthType.PUBLIC), - ("private", WebAppAuthType.INTERNAL), - ("private_all", WebAppAuthType.INTERNAL), - ("sso_verified", WebAppAuthType.EXTERNAL), - ], -) -def test_get_app_auth_type_should_map_access_modes_correctly( - access_mode: str, - expected: WebAppAuthType, -) -> None: - # Arrange - # Act - result = WebAppAuthService.get_app_auth_type(access_mode=access_mode) - - # Assert - assert result == expected - - -def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None: - # Arrange - mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") - mocker.patch( - "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", - return_value=SimpleNamespace(access_mode="private_all"), - ) - - # Act - result = WebAppAuthService.get_app_auth_type(app_code="app-code") - - # Assert - assert result == WebAppAuthType.INTERNAL - - -def test_get_app_auth_type_should_raise_when_no_input_provided() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"): - WebAppAuthService.get_app_auth_type() - - -def 
test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None: - # Arrange - # Act + Assert - with pytest.raises(ValueError, match="Could not determine app authentication type"): - WebAppAuthService.get_app_auth_type(access_mode="unknown") From b4e541e11a0571eaf9820d00971956bcc4dc1cfe Mon Sep 17 00:00:00 2001 From: Desel72 Date: Tue, 24 Mar 2026 22:45:13 -0500 Subject: [PATCH 6/8] test: migrate advanced prompt template service tests to testcontainers (#34034) --- .../test_advanced_prompt_template_service.py | 214 ------------------ 1 file changed, 214 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_advanced_prompt_template_service.py diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py deleted file mode 100644 index a6bc79e82b..0000000000 --- a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Unit tests for services.advanced_prompt_template_service -""" - -import copy - -from core.prompt.prompt_templates.advanced_prompt_templates import ( - BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, - CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, - CONTEXT, -) -from models.model import AppMode -from services.advanced_prompt_template_service import AdvancedPromptTemplateService - - -class TestAdvancedPromptTemplateService: - """Test suite for AdvancedPromptTemplateService.""" - - def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: - """Test baichuan model names use baichuan context prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "chat", - 
"model_name": "Baichuan2-13B", - "has_context": "true", - } - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - - def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: - """Test non-baichuan model names use common prompt.""" - # Arrange - args = { - "app_mode": AppMode.CHAT, - "model_mode": "completion", - "model_name": "gpt-4", - "has_context": "false", - } - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_prompt(args) - - # Assert - assert result == original_config - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: - """Test invalid app mode returns empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "chat" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} - - def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: - """Test context is prepended for completion prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) - assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: - """Test context is prepended for chat prompt when has_context is true.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") - - # Assert - assert 
result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) - assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: - """Test chat prompt remains unchanged when has_context is false.""" - # Arrange - original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") - - # Assert - assert result == original_config - assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: - """Test completion app mode with completion model returns completion prompt.""" - # Arrange - original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") - - # Assert - assert result == original_config - assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: - """Test invalid model mode returns empty dict.""" - # Arrange - app_mode = AppMode.CHAT - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") - - # Assert - assert result == {} - - def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps completion prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) - original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] - - # Act - result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"] == original_text - - def 
test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: - """Test helper keeps chat prompt unchanged when context is disabled.""" - # Arrange - prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) - original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] - - # Act - result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text - - def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: - """Test baichuan chat/completion returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: - """Test baichuan completion/chat returns the expected config.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") - - # Assert - assert result == original_config - assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: - """Test baichuan completion/completion prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") - - # Assert - assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) - 
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: - """Test baichuan chat/chat prepends baichuan context when enabled.""" - # Arrange - original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") - - # Assert - assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) - assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG - - def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: - """Test invalid baichuan mode combinations return empty dict.""" - # Arrange - app_mode = "invalid" - model_mode = "invalid" - - # Act - result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") - - # Assert - assert result == {} From 4c32acf857e53feadf7ad12019f7e45838efc372 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 25 Mar 2026 04:46:22 +0100 Subject: [PATCH 7/8] refactor: select in console datasets segments and API key controllers (#34027) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/datasets/datasets.py | 42 ++++++++------ .../console/datasets/datasets_segments.py | 56 +++++++++---------- .../console/datasets/test_datasets.py | 31 ++++------ .../datasets/test_datasets_segments.py | 50 ++++++++--------- 4 files changed, 86 insertions(+), 93 deletions(-) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index fb98932269..c79b377bb2 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -3,7 +3,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from 
pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -738,20 +738,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -802,9 +805,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -839,14 +845,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == 
self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -857,7 +863,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index fa9bc7f159..5c0c93e3ba 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -401,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,10 
+559,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -616,10 +616,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +666,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +714,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - 
.first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +771,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 68a7b30b9e..ff565f19fd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1544,8 
+1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py 
b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 1482499c41..306a772fd1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: 
return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -831,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -880,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -924,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -970,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - 
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1180,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1215,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", From d87263f7c3b3ef1309d3d6f40fa73eb32e89c5a1 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Wed, 25 Mar 2026 04:47:25 +0100 Subject: [PATCH 8/8] refactor: select in console datasets document controller (#34029) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands/vector.py | 8 ++- api/controllers/console/datasets/datasets.py | 9 +-- .../console/datasets/datasets_document.py | 7 ++- .../console/datasets/datasets_segments.py | 9 +-- .../service_api/dataset/dataset.py | 16 ++++-- .../service_api/dataset/segment.py | 9 +-- .../annotation_reply/annotation_reply.py | 3 +- api/core/indexing_runner.py | 20 +++---- api/core/rag/docstore/dataset_docstore.py | 3 +- .../rag/index_processor/index_processor.py | 3 +- .../processor/paragraph_index_processor.py | 10 ++-- 
.../processor/parent_child_index_processor.py | 8 +-- .../processor/qa_index_processor.py | 6 +- api/core/rag/retrieval/dataset_retrieval.py | 6 +- api/core/rag/summary_index/summary_index.py | 3 +- .../dataset_multi_retriever_tool.py | 3 +- .../dataset_retriever_tool.py | 5 +- api/models/dataset.py | 4 +- api/services/dataset_service.py | 55 ++++++++++--------- .../rag_pipeline/rag_pipeline_dsl_service.py | 17 +++--- .../rag_pipeline_transform_service.py | 18 +++--- api/services/summary_index_service.py | 13 +++-- api/services/vector_service.py | 12 ++-- .../add_annotation_to_index_task.py | 3 +- .../batch_import_annotations_task.py | 3 +- .../delete_annotation_index_task.py | 3 +- .../disable_annotation_reply_task.py | 3 +- .../enable_annotation_reply_task.py | 5 +- .../update_annotation_to_index_task.py | 3 +- .../batch_create_segment_to_index_task.py | 4 +- api/tasks/document_indexing_task.py | 4 +- api/tasks/generate_summary_index_task.py | 3 +- api/tasks/regenerate_summary_index_task.py | 4 +- .../test_dataset_retrieval_integration.py | 6 +- .../services/dataset_service_update_delete.py | 3 +- .../test_dataset_permission_service.py | 3 +- .../services/test_dataset_service.py | 24 ++++---- .../test_dataset_service_delete_dataset.py | 6 +- .../test_dataset_service_get_segments.py | 3 +- .../test_dataset_service_retrieval.py | 3 +- .../test_dataset_service_update_dataset.py | 39 ++++++------- .../services/test_tag_service.py | 3 +- .../tasks/test_add_document_to_index_task.py | 4 +- ...test_batch_create_segment_to_index_task.py | 4 +- .../tasks/test_clean_dataset_task.py | 6 +- .../test_create_segment_to_index_task.py | 4 +- .../tasks/test_dataset_indexing_task.py | 3 +- .../test_delete_segment_from_index_task.py | 4 +- .../test_disable_segment_from_index_task.py | 4 +- .../test_disable_segments_from_index_task.py | 4 +- .../tasks/test_document_indexing_sync_task.py | 4 +- .../tasks/test_document_indexing_task.py | 5 +- .../test_document_indexing_update_task.py 
| 4 +- .../test_duplicate_document_indexing_task.py | 6 +- .../test_enable_segments_to_index_task.py | 4 +- 55 files changed, 233 insertions(+), 195 deletions(-) diff --git a/api/commands/vector.py b/api/commands/vector.py index bef18bf73b..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,7 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -86,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -178,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index c79b377bb2..27c772fbe0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ 
-29,6 +29,7 @@ from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class 
DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 074694e7ea..897724182f 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -27,6 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db @@ -449,7 +450,7 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: @@ -463,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except 
InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -1337,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5c0c93e3ba..7333fcaa07 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = 
ModelManager() model_manager.get_model_instance( @@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -569,7 +570,7 @@ class ChildChunkAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 89be847cd3..25b6436a71 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -15,6 +15,7 @@ from controllers.service_api.wraps import ( cloud_edition_billing_rate_limit_check, ) from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag @@ -153,9 +154,14 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: 
ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: item["embedding_available"] = True # type: ignore else: @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..595b01a9f2 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,6 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from 
fields.segment_fields import child_chunk_fields, segment_fields @@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 87d4772815..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,6 +4,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from 
extensions.ext_database import db from models.dataset import Dataset from models.enums import CollectionBindingType, ConversationFromSource @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..06bc366081 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -271,7 +271,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. 
@@ -289,7 +289,7 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, @@ -303,7 +303,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -573,7 +573,7 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -587,7 +587,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +597,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +628,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY 
and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +654,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,7 +764,7 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..cd27113245 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,6 +6,7 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -71,7 +72,7 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac..a6d1db214b 100644 --- 
a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..726cc062f6 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -117,7 +117,7 @@ class 
ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type 
import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ 
b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..52061fd93d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -675,7 +675,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == 
IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +752,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -1068,7 +1068,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeee..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py 
b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..75b923fd8b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e6622..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as 
RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/models/dataset.py b/api/models/dataset.py index b4fb03a7f4..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -137,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = 
mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 65e112f1e9..969ca68545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.file import helpers as file_helpers from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType @@ -228,7 +228,7 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid @@ -254,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if 
indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -349,7 +352,7 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -717,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -953,8 +956,8 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error @@ -976,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == 
"economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -991,9 +994,9 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: @@ -1018,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. 
Please configure a valid provider " @@ -1029,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1089,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1907,8 +1910,8 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model @@ -2689,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2712,7 +2715,7 @@ class DocumentService: 
tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3125,7 +3128,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3208,7 +3211,7 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3230,7 +3233,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( @@ -3345,7 +3348,7 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3382,7 +3385,7 @@ class SegmentService: # When user manually 
provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3409,7 +3412,7 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3449,7 +3452,7 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3481,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..fd66d55c1a 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import 
generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, 
retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 7dcfecdd1d..215a8c8528 100644 --- 
a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,7 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -105,29 +105,29 @@ class RagPipelineTransformService: if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == 
IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -170,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..ed7a33feae 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,6 +12,7 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. 
If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -724,7 +725,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +852,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +890,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. 
""" # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +982,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1013,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..bb94a03ba3 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document @@ -45,7 +45,7 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model 
instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from 
core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from 
core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import 
Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 7f810129ef..dd58378e0e 100644 --- 
a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,7 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -120,7 +120,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b5794e33e2..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,7 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -127,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: 
summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index ac5d23408a..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,7 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -53,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing 
technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index ea8d04502a..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -39,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -460,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- 
a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 55bfb64e18..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py 
b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index c4d20bc02c..0702680f5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -63,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -157,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -181,13 +181,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + 
assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -273,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -306,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -314,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -589,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -685,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", 
embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 807d18322c..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,7 +3,7 @@ from unittest.mock import patch from uuid import uuid4 -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -109,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -208,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- 
a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py 
b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..2899d5b8a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,6 +4,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY 
assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = 
{ - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 
@@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 1a72e3b6c2..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 5ebf141828..d2e343ef52 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,7 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -142,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), 
data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 9449fee0af..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,7 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -154,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -870,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 979435282b..9f8e37fc9e 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,7 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -121,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) 
db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index d21f1daf23..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,7 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from 
core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -100,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index fbcb7b5264..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -103,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 
10d97919fb..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,7 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -57,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + 
indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index c650d56091..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -64,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 76b6a8ae73..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ 
b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -110,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -245,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, 
TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset)