diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index f75dbe88b21..b1419b3adfa 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -24,14 +24,19 @@ from models.model import Message logger = logging.getLogger(__name__) -class BasedGenerateTaskPipeline: +class BasedGenerateTaskPipeline[AppGenerateEntityT: AppGenerateEntity]: """ BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. + + The type parameter preserves the concrete application generate entity for + subclasses after the shared initializer stores it on ``_application_generate_entity``. """ + _application_generate_entity: AppGenerateEntityT + def __init__( self, - application_generate_entity: AppGenerateEntity, + application_generate_entity: AppGenerateEntityT, queue_manager: AppQueueManager, stream: bool, ): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index bea50ea2696..a728069eede 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -65,19 +65,20 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought, Me logger = logging.getLogger(__name__) +type EasyUIAppGenerateEntity = ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity -class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): + +class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline[EasyUIAppGenerateEntity]): """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ _task_state: EasyUITaskState - _application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity _precomputed_event_type: StreamEvent | None = None def __init__( self, - application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity, + application_generate_entity: EasyUIAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -310,12 +311,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): yield response case QueueLLMChunkEvent() | QueueAgentMessageEvent(): chunk = event.chunk - delta_text = chunk.delta.message.content - if delta_text is None: + delta_content = chunk.delta.message.content + if delta_content is None: continue - if isinstance(chunk.delta.message.content, list): + if isinstance(delta_content, list): + # EasyUI streams text only; structured multimodal chunks contribute their text parts. delta_text = "" - for content in chunk.delta.message.content: + for content in delta_content: logger.debug( "The content type %s in LLM chunk delta message content.: %r", type(content), content ) @@ -331,17 +333,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): content, ) continue + else: + delta_text = delta_content if not self._task_state.llm_result.prompt_messages: self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) + should_direct_answer = self._handle_output_moderation_chunk(delta_text) if should_direct_answer: continue current_content = cast(str, self._task_state.llm_result.message.content) - current_content += cast(str, delta_text) + current_content += delta_text self._task_state.llm_result.message.content = current_content match event: @@ -352,13 +356,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message_id=self._message_id ) yield self._message_cycle_manager.message_to_stream_response( - answer=cast(str, delta_text), + answer=delta_text, message_id=self._message_id, event_type=self._precomputed_event_type, ) case _: yield self._agent_message_to_stream_response( - answer=cast(str, delta_text), + answer=delta_text, message_id=self._message_id, ) case QueueMessageReplaceEvent(): @@ -389,9 +393,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if not conversation: raise ValueError(f"Conversation {self._conversation_id} not found") - message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + saved_prompt = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) + object.__setattr__(message, "message", saved_prompt) message.message_tokens = usage.prompt_tokens message.message_unit_price = usage.prompt_unit_price message.message_price_unit = usage.prompt_price_unit diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 30b9523146f..b2073716d13 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -107,7 +107,7 @@ class LLMGenerator: tenant_id=tenant_id, model_type=ModelType.LLM, ) - prompts = [UserPromptMessage(content=prompt)] + prompts: list[PromptMessage] = [UserPromptMessage(content=prompt)] with measure_time() as timer: response: LLMResult = model_instance.invoke_llm( @@ -201,11 +201,13 @@ class LLMGenerator: except InvokeAuthorizationError: return [] - prompt_messages = [UserPromptMessage(content=prompt)] + prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)] questions: Sequence[str] = [] try: + model_parameters: dict[str, object] + stop: list[str] configured_completion_params = configured_model.get("completion_params") if use_configured_model and isinstance(configured_completion_params, dict): model_parameters, stop = _normalize_completion_params(configured_completion_params) @@ -253,7 +255,7 @@ class LLMGenerator: remove_template_variables=False, ) - prompt_messages = [UserPromptMessage(content=prompt_generate)] + no_variable_prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate)] model_manager = ModelManager.for_tenant(tenant_id=tenant_id) @@ -266,7 +268,7 @@ class LLMGenerator: try: response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + prompt_messages=list(no_variable_prompt_messages), model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.get_text_content() @@ -299,7 +301,7 @@ class LLMGenerator: }, remove_template_variables=False, ) - prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] + prompt_generate_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance model_manager = ModelManager.for_tenant(tenant_id=tenant_id) @@ -314,7 +316,7 @@ class LLMGenerator: try: # the first step to generate the task prompt prompt_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + prompt_messages=list(prompt_generate_messages), model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -331,7 +333,7 @@ class LLMGenerator: }, remove_template_variables=False, ) - parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] + parameter_messages: list[PromptMessage] = [UserPromptMessage(content=parameter_generate_prompt)] # the second step to generate the task_parameter and task_statement statement_generate_prompt = statement_template.format( @@ -341,7 +343,7 @@ class LLMGenerator: }, remove_template_variables=False, ) - statement_messages = [UserPromptMessage(content=statement_generate_prompt)] + statement_messages: list[PromptMessage] = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content: LLMResult = model_instance.invoke_llm( @@ -397,7 +399,7 @@ class LLMGenerator: model=args.model_config_data.name, ) - prompt_messages = [UserPromptMessage(content=prompt)] + prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)] model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( @@ -455,7 +457,7 @@ class LLMGenerator: model=args.model_config_data.name, ) - prompt_messages = [ + prompt_messages: list[PromptMessage] = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), UserPromptMessage(content=args.instruction), ] @@ -634,7 +636,7 @@ class LLMGenerator: system_prompt = LLM_MODIFY_CODE_SYSTEM case _: system_prompt = LLM_MODIFY_PROMPT_SYSTEM - prompt_messages = [ + prompt_messages: list[PromptMessage] = [ SystemPromptMessage(content=system_prompt), UserPromptMessage( content=json.dumps( diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 11414832e38..9dba193f751 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, cast +from typing import NotRequired, TypedDict, cast from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( @@ -13,19 +13,46 @@ from graphon.model_runtime.entities import ( ) +class SavedPromptFile(TypedDict): + type: str + data: str + detail: NotRequired[str] + format: NotRequired[str] + + +class SavedPromptToolCallFunction(TypedDict): + name: str + arguments: str + + +class SavedPromptToolCall(TypedDict): + id: str + type: str + function: SavedPromptToolCallFunction + + +class SavedPrompt(TypedDict): + role: str + text: str + files: NotRequired[list[SavedPromptFile]] + tool_calls: NotRequired[list[SavedPromptToolCall]] + + class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]): + def prompt_messages_to_prompt_for_saving( + model_mode: str, prompt_messages: Sequence[PromptMessage] + ) -> list[SavedPrompt]: """ Prompt messages to prompt for saving. :param model_mode: model mode :param prompt_messages: prompt messages :return: """ - prompts = [] + prompts: list[SavedPrompt] = [] if model_mode == ModelMode.CHAT: - tool_calls = [] for prompt_message in prompt_messages: + tool_calls: list[SavedPromptToolCall] = [] if prompt_message.role == PromptMessageRole.USER: role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: @@ -50,7 +77,7 @@ class PromptMessageUtil: continue text = "" - files = [] + files: list[SavedPromptFile] = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: match content: @@ -77,7 +104,7 @@ class PromptMessageUtil: else: text = cast(str, prompt_message.content) - prompt = {"role": role, "text": text, "files": files} + prompt: SavedPrompt = {"role": role, "text": text, "files": files} if tool_calls: prompt["tool_calls"] = tool_calls @@ -86,14 +113,14 @@ class PromptMessageUtil: else: prompt_message = prompt_messages[0] text = "" - files = [] + prompt_files: list[SavedPromptFile] = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: if content.type == PromptMessageContentType.TEXT: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append( + prompt_files.append( { "type": "image", "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], @@ -103,13 +130,13 @@ class PromptMessageUtil: else: text = cast(str, prompt_message.content) - params: dict[str, Any] = { + params: SavedPrompt = { "role": "user", "text": text, } - if files: - params["files"] = files + if prompt_files: + params["files"] = prompt_files prompts.append(params) diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index ae7bebcb9bb..1f129b8ed96 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -105,6 +105,8 @@ class WaterCrawlProvider: def scrape_url(self, url: str) -> WatercrawlDocumentData: response = self.client.scrape_url(url=url, sync=True, prefetched=True) + if not isinstance(response, dict): + raise ValueError("Invalid scrape response. Expected a JSON dictionary.") return self._structure_data(response) def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index db38990c146..c2edc7c4a7b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -15,6 +15,8 @@ from urllib.parse import urlparse from docx import Document as DocxDocument from docx.oxml.ns import qn +from docx.table import Table +from docx.text.paragraph import Paragraph from docx.text.run import Run from configs import dify_config @@ -286,10 +288,10 @@ class WordExtractor(BaseExtractor): return "".join(paragraph_content).strip() - def parse_docx(self, docx_path): + def parse_docx(self, docx_path: str) -> str: doc = DocxDocument(docx_path) - content = [] + content: list[str] = [] image_map = self._extract_images_from_docx(doc) @@ -445,18 +447,11 @@ class WordExtractor(BaseExtractor): process_hyperlink(child, paragraph_content) return "".join(paragraph_content) if paragraph_content else "" - paragraphs = doc.paragraphs.copy() - tables = doc.tables.copy() - for element in doc.element.body: - if hasattr(element, "tag"): - if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph - para = paragraphs.pop(0) - parsed_paragraph = parse_paragraph(para) - if parsed_paragraph.strip(): - content.append(parsed_paragraph) - else: - content.append("\n") - elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table - table = tables.pop(0) - content.append(self._table_to_markdown(table, image_map)) + for block in doc.iter_inner_content(): + match block: + case Paragraph(): + parsed_paragraph = parse_paragraph(block) + content.append(parsed_paragraph if parsed_paragraph.strip() else "\n") + case Table(): + content.append(self._table_to_markdown(block, image_map)) return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index b0e66c83fe2..8da401b226a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -16,11 +16,9 @@ from sqlalchemy import select from configs import dify_config from core.entities.knowledge_entities import PreviewDetail from core.file import remote_fetcher -from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import AttachmentDocument, Document -from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, @@ -99,18 +97,6 @@ class BaseIndexProcessor(ABC): def format_preview(self, chunks: Any) -> Mapping[str, Any]: raise NotImplementedError - @abstractmethod - def retrieve( - self, - retrieval_method: RetrievalMethod, - query: str, - dataset: Dataset, - top_k: int, - score_threshold: float, - reranking_model: RerankingModelDict, - ) -> list[Document]: - raise NotImplementedError - def _get_splitter( self, processing_rule_mode: str, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 7c7e8ab09da..c4f28ae2164 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -16,9 +16,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor -from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword -from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.entities import Rule @@ -28,7 +26,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk -from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db @@ -182,35 +179,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.delete() - @override - def retrieve( - self, - retrieval_method: RetrievalMethod, - query: str, - dataset: Dataset, - top_k: int, - score_threshold: float, - reranking_model: RerankingModelDict, - ) -> list[Document]: - # Set search parameters. - results = RetrievalService.retrieve( - retrieval_method=retrieval_method, - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - ) - # Organize results. - docs = [] - for result in results: - metadata = result.metadata - metadata["score"] = result.score - if result.score >= score_threshold: - doc = Document(page_content=result.page_content, metadata=metadata) - docs.append(doc) - return docs - @override def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: documents: list[Any] = [] diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index bf9145def1d..9c186a9f046 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -12,8 +12,6 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor -from core.rag.data_post_processor.data_post_processor import RerankingModelDict -from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.entities import ParentMode, Rule @@ -23,7 +21,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk -from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper from models import Account @@ -223,35 +220,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor): ) db.session.commit() - @override - def retrieve( - self, - retrieval_method: RetrievalMethod, - query: str, - dataset: Dataset, - top_k: int, - score_threshold: float, - reranking_model: RerankingModelDict, - ) -> list[Document]: - # Set search parameters. - results = RetrievalService.retrieve( - retrieval_method=retrieval_method, - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - ) - # Organize results. - docs = [] - for result in results: - metadata = result.metadata - metadata["score"] = result.score - if result.score >= score_threshold: - doc = Document(page_content=result.page_content, metadata=metadata) - docs.append(doc) - return docs - def _split_child_nodes( self, document_node: Document, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 7d1e7333a8b..253acebc2c6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,8 +15,6 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor -from core.rag.data_post_processor.data_post_processor import RerankingModelDict -from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.entities import Rule @@ -25,7 +23,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk -from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.account import Account @@ -187,35 +184,6 @@ class QAIndexProcessor(BaseIndexProcessor): else: vector.delete() - @override - def retrieve( - self, - retrieval_method: RetrievalMethod, - query: str, - dataset: Dataset, - top_k: int, - score_threshold: float, - reranking_model: RerankingModelDict, - ): - # Set search parameters. - results = RetrievalService.retrieve( - retrieval_method=retrieval_method, - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - ) - # Organize results. - docs = [] - for result in results: - metadata = result.metadata - metadata["score"] = result.score - if result.score >= score_threshold: - doc = Document(page_content=result.page_content, metadata=metadata) - docs.append(doc) - return docs - @override def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: qa_chunks = QAStructureChunk.model_validate(chunks) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f3fd2d4d8ff..f4e850d34ed 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -609,7 +609,7 @@ class DatasetRetrieval: metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataFilteringCondition | None = None, ): - tools = [] + tools: list[PromptMessageTool] = [] for dataset in available_datasets: description = dataset.description if not description: @@ -1162,7 +1162,7 @@ class DatasetRetrieval: :param invoke_from: invoke from :param hit_callback: hit callback """ - tools = [] + tools: list[DatasetRetrieverBaseTool] = [] available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index fce065eda9d..feb3bc7a4cb 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -30,6 +30,11 @@ class CelerySSLOptionsDict(TypedDict): ssl_keyfile: str | None +class CeleryBeatScheduleEntry(TypedDict): + task: str + schedule: crontab | timedelta + + def get_celery_ssl_options() -> CelerySSLOptionsDict | None: """Get SSL configuration for Celery broker/backend connections.""" # Only apply SSL if we're using Redis as broker/backend @@ -152,7 +157,7 @@ def init_app(app: DifyApp) -> Celery: day = dify_config.CELERY_BEAT_SCHEDULER_TIME # if you add a new task, please add the switch to CeleryScheduleTasksConfig - beat_schedule = {} + beat_schedule: dict[str, CeleryBeatScheduleEntry] = {} if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK: imports.append("schedule.clean_embedding_cache_task") beat_schedule["clean_embedding_cache_task"] = { diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index 9b9b4f2c15f..8598c579b3c 100644 --- a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from typing import Any, cast, override import mlflow -from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType +from mlflow.entities import Document, LiveSpan, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context @@ -31,6 +31,8 @@ from models.workflow import WorkflowNodeExecutionModel logger = logging.getLogger(__name__) +type SpanAttributes = dict[str, object] + def datetime_to_nanoseconds(dt: datetime | None) -> int | None: """Convert datetime to nanosecond timestamp for MLflow API""" @@ -39,6 +41,32 @@ def datetime_to_nanoseconds(dt: datetime | None) -> int | None: return int(dt.timestamp() * 1_000_000_000) +def _start_span_no_context( + *, + name: str, + span_type: str, + parent_span: LiveSpan | None = None, + inputs: object | None = None, + attributes: SpanAttributes | None = None, + start_time_ns: int | None = None, +) -> LiveSpan: + """Start an MLflow span while preserving structured Dify attributes. + + MLflow 3.11 annotates `start_span_no_context(..., attributes=...)` as `dict[str, str]`, + but the implementation immediately calls `LiveSpan.set_attributes(dict[str, Any])`. + `LiveSpan` JSON-serializes arbitrary values before storing them in OpenTelemetry, and + reserved attributes like `mlflow.chat.tokenUsage` are expected to round-trip as dicts. + """ + return start_span_no_context( + name=name, + span_type=span_type, + parent_span=parent_span, + inputs=inputs, + attributes=cast(dict[str, str] | None, attributes), + start_time_ns=start_time_ns, + ) + + class MLflowDataTrace(BaseTraceInstance): def __init__(self, config: MLflowConfig | DatabricksConfig): super().__init__(config) @@ -119,7 +147,7 @@ class MLflowDataTrace(BaseTraceInstance): if trace_info.query: workflow_inputs["query"] = trace_info.query - workflow_span = start_span_no_context( + workflow_span = _start_span_no_context( name=TraceTaskName.WORKFLOW_TRACE.value, span_type=SpanType.CHAIN, inputs=workflow_inputs, @@ -139,7 +167,7 @@ class MLflowDataTrace(BaseTraceInstance): # Create child spans for workflow nodes for node in self._get_workflow_nodes(trace_info.workflow_run_id): inputs = None - attributes = { + attributes: SpanAttributes = { "node_id": node.id, "node_type": node.node_type, "status": node.status, @@ -157,7 +185,7 @@ class MLflowDataTrace(BaseTraceInstance): if not inputs: inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {} - node_span = start_span_no_context( + node_span = _start_span_no_context( name=node.title, span_type=self._get_node_span_type(node.node_type), parent_span=workflow_span, @@ -212,7 +240,7 @@ class MLflowDataTrace(BaseTraceInstance): end_time_ns=datetime_to_nanoseconds(trace_info.end_time), ) - def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]: + def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[object, SpanAttributes]: """Parse LLM inputs and attributes from LLM workflow node""" if node.process_data is None: return {}, {} @@ -266,16 +294,16 @@ class MLflowDataTrace(BaseTraceInstance): base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") file_list.append(f"{base_url}/{message_file_data.url}") - span = start_span_no_context( + span = _start_span_no_context( name=TraceTaskName.MESSAGE_TRACE.value, span_type=SpanType.LLM, - inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type] + inputs=self._parse_prompts(trace_info.inputs), attributes={ - "message_id": trace_info.message_id, # type: ignore[dict-item] + "message_id": trace_info.message_id, "model_provider": trace_info.message_data.model_provider, "model_id": trace_info.message_data.model_id, "conversation_mode": trace_info.conversation_mode, - "file_list": file_list, # type: ignore[dict-item] + "file_list": file_list, "total_price": trace_info.message_data.total_price, **trace_info.metadata, }, @@ -330,15 +358,15 @@ class MLflowDataTrace(BaseTraceInstance): return metadata.get("from_account_id") # type: ignore[return-value] def tool_trace(self, trace_info: ToolTraceInfo): - span = start_span_no_context( + span = _start_span_no_context( name=trace_info.tool_name, span_type=SpanType.TOOL, - inputs=trace_info.tool_inputs, # type: ignore[arg-type] + inputs=trace_info.tool_inputs, attributes={ - "message_id": trace_info.message_id, # type: ignore[dict-item] - "metadata": trace_info.metadata, # type: ignore[dict-item] - "tool_config": trace_info.tool_config, # type: ignore[dict-item] - "tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item] + "message_id": trace_info.message_id, + "metadata": trace_info.metadata, + "tool_config": trace_info.tool_config, + "tool_parameters": trace_info.tool_parameters, }, start_time_ns=datetime_to_nanoseconds(trace_info.start_time), ) @@ -367,13 +395,13 @@ class MLflowDataTrace(BaseTraceInstance): return start_time = trace_info.start_time or trace_info.message_data.created_at - span = start_span_no_context( + span = _start_span_no_context( name=TraceTaskName.MODERATION_TRACE.value, span_type=SpanType.TOOL, inputs=trace_info.inputs or {}, attributes={ - "message_id": trace_info.message_id, # type: ignore[dict-item] - "metadata": trace_info.metadata, # type: ignore[dict-item] + "message_id": trace_info.message_id, + "metadata": trace_info.metadata, }, start_time_ns=datetime_to_nanoseconds(start_time), ) @@ -391,13 +419,13 @@ class MLflowDataTrace(BaseTraceInstance): if trace_info.message_data is None: return - span = start_span_no_context( + span = _start_span_no_context( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, span_type=SpanType.RETRIEVER, inputs=trace_info.inputs, attributes={ - "message_id": trace_info.message_id, # type: ignore[dict-item] - "metadata": trace_info.metadata, # type: ignore[dict-item] + "message_id": trace_info.message_id, + "metadata": trace_info.metadata, }, start_time_ns=datetime_to_nanoseconds(trace_info.start_time), ) @@ -410,15 +438,15 @@ class MLflowDataTrace(BaseTraceInstance): start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - span = start_span_no_context( + span = _start_span_no_context( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, span_type=SpanType.TOOL, inputs=trace_info.inputs, attributes={ - "message_id": trace_info.message_id, # type: ignore[dict-item] - "model_provider": trace_info.model_provider, # type: ignore[dict-item] - "model_id": trace_info.model_id, # type: ignore[dict-item] - "total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item] + "message_id": trace_info.message_id, + "model_provider": trace_info.model_provider, + "model_id": trace_info.model_id, + "total_tokens": trace_info.total_tokens or 0, }, start_time_ns=datetime_to_nanoseconds(start_time), ) @@ -439,11 +467,11 @@ class MLflowDataTrace(BaseTraceInstance): span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time)) def generate_name_trace(self, trace_info: GenerateNameTraceInfo): - span = start_span_no_context( + span = _start_span_no_context( name=TraceTaskName.GENERATE_NAME_TRACE.value, span_type=SpanType.CHAIN, inputs=trace_info.inputs, - attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item] + attributes={"message_id": trace_info.message_id}, start_time_ns=datetime_to_nanoseconds(trace_info.start_time), ) span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time)) diff --git a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index 324f894b252..08c1b52c88f 100644 --- a/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch import pytest from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, @@ -361,6 +362,7 @@ class TestWorkflowTrace: assert inputs["query"] == "hello" def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db): + usage = {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} llm_node = _make_node( node_type=BuiltinNodeTypes.LLM, process_data=json.dumps( @@ -369,7 +371,7 @@ class TestWorkflowTrace: "model_name": "gpt-4", "model_provider": "openai", "finish_reason": "stop", - "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + "usage": usage, } ), outputs='{"text": "hello world"}', @@ -383,6 +385,14 @@ class TestWorkflowTrace: trace_instance.workflow_trace(_make_workflow_trace_info()) assert mock_tracing["start"].call_count == 2 + node_start_call = mock_tracing["start"].call_args_list[1] + attrs = node_start_call.kwargs["attributes"] + assert attrs[SpanAttributeKey.CHAT_USAGE] == { + TokenUsageKey.INPUT_TOKENS: 5, + TokenUsageKey.OUTPUT_TOKENS: 10, + TokenUsageKey.TOTAL_TOKENS: 15, + } + assert attrs["usage"] == usage node_span.end.assert_called_once() workflow_span.end.assert_called_once() @@ -631,6 +641,27 @@ class TestMessageTrace: assert "http://files.test/path/to/file.png" in attrs["file_list"] assert "existing_file.txt" in attrs["file_list"] + def test_message_trace_preserves_structured_span_attributes(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_message_trace_info( + metadata={ + "conversation_id": "c1", + "from_account_id": "a1", + "routing": {"node": "answer", "score": 0.7}, + }, + file_list=["existing_file.txt"], + ) + trace_instance.message_trace(trace_info) + + attrs = mock_tracing["start"].call_args.kwargs["attributes"] + assert attrs["message_id"] == "msg-1" + assert attrs["total_price"] == 0.01 + assert attrs["routing"] == {"node": "answer", "score": 0.7} + assert attrs["file_list"] == ["existing_file.txt"] + def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db): span = MagicMock() mock_tracing["start"].return_value = span diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index c97014ebcfe..cd451dec259 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -1,9 +1,3 @@ -core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py -core/llm_generator/llm_generator.py -providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py -core/prompt/utils/prompt_message_util.py -core/rag/retrieval/dataset_retrieval.py -extensions/ext_celery.py providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py @@ -59,11 +53,6 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py -core/rag/extractor/watercrawl/provider.py -core/rag/extractor/word_extractor.py -core/rag/index_processor/processor/paragraph_index_processor.py -core/rag/index_processor/processor/parent_child_index_processor.py -core/rag/index_processor/processor/qa_index_processor.py core/tools/mcp_tool/provider.py core/tools/plugin_tool/provider.py core/tools/workflow_as_tool/provider.py diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index 0f2c79f9fc7..18382f053be 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -1,7 +1,9 @@ from __future__ import annotations +from collections.abc import Generator, Sequence from datetime import UTC, datetime -from types import SimpleNamespace +from threading import Thread +from typing import cast from unittest.mock import Mock import pytest @@ -13,8 +15,11 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, ) +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( + AppQueueEvent, + MessageQueueMessage, QueueAgentMessageEvent, QueueAgentThoughtEvent, QueueAnnotationReplyEvent, @@ -26,23 +31,33 @@ from core.app.entities.queue_entities import ( QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, + WorkflowQueueMessage, ) from core.app.entities.task_entities import ( + AgentMessageStreamResponse, + AgentThoughtStreamResponse, ChatbotAppStreamResponse, CompletionAppStreamResponse, ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, PingStreamResponse, + StreamEvent, ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from core.base.tts import AudioTrunk +from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.entities.trace_entity import TraceTaskName -from graphon.file import FileTransferMethod +from core.ops.ops_trace_manager import TraceQueueManager +from extensions.storage.storage_type import StorageType +from graphon.file import FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent -from models.model import AppMode +from models.enums import CreatorUserRole +from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile class _DummyModelConf: @@ -50,6 +65,150 @@ class _DummyModelConf: self.model = "mock" +class _FakeDb: + engine: object = object() + + +class _UnknownQueueEvent: + pass + + +class _AnnotationReply: + def __init__(self, content: str) -> None: + self.content = content + + +class _ModelConfigMode: + def __init__(self, mode: str) -> None: + self.mode = mode + + +class _ProviderModelBundle: + def __init__(self, model_type_instance: object) -> None: + self.model_type_instance = model_type_instance + + +class _AudioPublisher: + def __init__(self, status: str, audio: str) -> None: + self._audio = AudioTrunk(status, audio) + + def check_and_get_audio(self) -> AudioTrunk: + return self._audio + + +class _TraceManagerDouble: + def __init__(self) -> None: + self.add_trace_task = Mock() + + +class _FakeQueueManager(AppQueueManager): + def __init__(self) -> None: + self._events: list[MessageQueueMessage | WorkflowQueueMessage] = [] + self.published_events: list[AppQueueEvent] = [] + + def set_events(self, events: list[MessageQueueMessage | WorkflowQueueMessage]) -> None: + self._events = events + + def listen(self) -> Generator[MessageQueueMessage | WorkflowQueueMessage]: + yield from self._events + + def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + self.published_events.append(event) + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + self.published_events.append(event) + + +def _queue_message(event: AppQueueEvent) -> MessageQueueMessage: + return MessageQueueMessage( + task_id="task", + app_mode=AppMode.CHAT.value, + message_id="msg", + conversation_id="conv", + event=event, + ) + + +def _unknown_queue_message() -> MessageQueueMessage: + return MessageQueueMessage.model_construct( + task_id="task", + app_mode=AppMode.CHAT.value, + message_id="msg", + conversation_id="conv", + event=cast(AppQueueEvent, _UnknownQueueEvent()), + ) + + +def _make_conversation(app_mode: AppMode) -> Conversation: + conversation = Conversation() + conversation.id = "conv" + conversation.mode = app_mode + return conversation + + +def _make_message() -> Message: + message = Message() + message.id = "msg" + message.created_at = datetime.now(UTC) + return message + + +def _message_file( + *, + file_id: str, + transfer_method: FileTransferMethod, + url: str | None, + upload_file_id: str | None, + file_type: FileType = FileType.IMAGE, +) -> MessageFile: + message_file = MessageFile( + message_id="msg", + type=file_type, + transfer_method=transfer_method, + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user", + url=url, + upload_file_id=upload_file_id, + ) + message_file.id = file_id + return message_file + + +def _upload_file(*, file_id: str, name: str, mime_type: str, size: int, extension: str) -> UploadFile: + upload_file = UploadFile( + tenant_id="tenant", + storage_type=StorageType.LOCAL, + key=f"uploads/{file_id}", + name=name, + size=size, + extension=extension, + mime_type=mime_type, + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user", + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + return upload_file + + +def _agent_thought() -> MessageAgentThought: + thought = MessageAgentThought( + message_id="msg", + position=1, + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user", + thought="t", + observation="o", + tool="tool", + tool_labels_str="{}", + tool_input="input", + message_files="[]", + ) + thought.id = "thought" + return thought + + def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: return EasyUIBasedAppConfig( tenant_id="tenant", @@ -89,14 +248,42 @@ def _make_entity(entity_cls, app_mode: AppMode): ) +def _make_pipeline( + entity_cls: type[ChatAppGenerateEntity] | type[CompletionAppGenerateEntity] = ChatAppGenerateEntity, + app_mode: AppMode = AppMode.CHAT, + *, + stream: bool = True, +) -> tuple[EasyUIBasedGenerateTaskPipeline, _FakeQueueManager]: + queue_manager = _FakeQueueManager() + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(entity_cls, app_mode), + queue_manager=queue_manager, + conversation=_make_conversation(app_mode), + message=_make_message(), + stream=stream, + ) + return pipeline, queue_manager + + +def _set_queue_events( + pipeline: EasyUIBasedGenerateTaskPipeline, + events: Sequence[MessageQueueMessage | WorkflowQueueMessage], +) -> None: + cast(_FakeQueueManager, pipeline.queue_manager).set_events(list(events)) + + +def _set_method(obj: object, name: str, value: object) -> None: + object.__setattr__(obj, name, value) + + class TestEasyUiBasedGenerateTaskPipeline: def test_to_blocking_response_chat(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -111,12 +298,12 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.data.answer == "answer" def test_to_blocking_response_completion(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.COMPLETION) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -131,12 +318,12 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.data.answer == "answer" def test_listen_audio_msg_returns_none_when_no_publisher(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -145,12 +332,12 @@ class TestEasyUiBasedGenerateTaskPipeline: assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None def test_process_stream_response_handles_chunks_and_end(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -174,19 +361,24 @@ class TestEasyUiBasedGenerateTaskPipeline: ) events = [ - SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), - SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), - SimpleNamespace(event=QueuePingEvent()), - SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + _queue_message(QueueLLMChunkEvent(chunk=chunk)), + _queue_message(QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + _queue_message(QueuePingEvent()), + _queue_message(QueueMessageEndEvent(llm_result=llm_result)), ] - pipeline.queue_manager.listen = lambda: iter(events) - pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None - pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" - pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" - pipeline.handle_output_moderation_when_task_finished = lambda completion: None - pipeline._message_end_to_stream_response = lambda: "end" - pipeline._save_message = lambda **kwargs: None + _set_queue_events(pipeline, events) + + def _message_event_type(message_id: str) -> StreamEvent: + return StreamEvent.MESSAGE + + def _message_end() -> MessageEndStreamResponse: + return MessageEndStreamResponse(task_id="task", id="msg") + + _set_method(pipeline._message_cycle_manager, "get_message_event_type", _message_event_type) + _set_method(pipeline, "handle_output_moderation_when_task_finished", lambda completion: None) + _set_method(pipeline, "_message_end_to_stream_response", _message_end) + _set_method(pipeline, "_save_message", lambda **kwargs: None) class _Session: def __init__(self, *args, **kwargs): @@ -207,28 +399,30 @@ class TestEasyUiBasedGenerateTaskPipeline: ) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) responses = list(pipeline._process_stream_response(publisher=None)) - assert "chunk" in responses - assert "replace" in responses + message_response = next(item for item in responses if isinstance(item, MessageStreamResponse)) + assert message_response.answer == "hiyo" + replace_response = next(item for item in responses if isinstance(item, MessageReplaceStreamResponse)) + assert replace_response.answer == "replace" assert any(isinstance(item, PingStreamResponse) for item in responses) - assert responses[-1] == "end" + assert isinstance(responses[-1], MessageEndStreamResponse) + assert pipeline._task_state.llm_result.message.content == "done" def test_handle_output_moderation_chunk_directs_output(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - events: list[object] = [] class _Moderation: def should_direct_output(self): @@ -237,18 +431,18 @@ class TestEasyUiBasedGenerateTaskPipeline: def get_final_output(self): return "final" - pipeline.output_moderation_handler = _Moderation() - pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + _set_method(pipeline, "output_moderation_handler", _Moderation()) result = pipeline._handle_output_moderation_chunk("token") assert result is True + events = cast(_FakeQueueManager, pipeline.queue_manager).published_events assert any(isinstance(event, QueueLLMChunkEvent) for event in events) assert any(isinstance(event, QueueStopEvent) for event in events) def test_handle_stop_updates_usage(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() class _ModelType: def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): @@ -263,7 +457,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __init__(self) -> None: self.model = "mock" self.credentials = {} - self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + self.provider_model_bundle = _ProviderModelBundle(model_type_instance=_ModelType()) app_config = _make_app_config(AppMode.CHAT) application_generate_entity = ChatAppGenerateEntity.model_construct( @@ -286,7 +480,7 @@ class TestEasyUiBasedGenerateTaskPipeline: pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -315,46 +509,41 @@ class TestEasyUiBasedGenerateTaskPipeline: assert pipeline._task_state.llm_result.usage.completion_tokens == 5 def test_record_files_builds_file_payloads(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, ) message_files = [ - SimpleNamespace( - id="mf-1", - message_id="msg", + _message_file( + file_id="mf-1", transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/a.png", upload_file_id=None, - type="image", ), - SimpleNamespace( - id="mf-2", - message_id="msg", + _message_file( + file_id="mf-2", transfer_method=FileTransferMethod.LOCAL_FILE, url="", upload_file_id="upload-1", - type="image", ), - SimpleNamespace( - id="mf-3", - message_id="msg", + _message_file( + file_id="mf-3", transfer_method=FileTransferMethod.TOOL_FILE, url="tool/file.bin", upload_file_id=None, - type="file", + file_type=FileType.CUSTOM, ), ] upload_files = [ - SimpleNamespace( - id="upload-1", + _upload_file( + file_id="upload-1", name="local.png", mime_type="image/png", size=123, @@ -389,7 +578,7 @@ class TestEasyUiBasedGenerateTaskPipeline: ) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) monkeypatch.setattr( "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", @@ -407,12 +596,12 @@ class TestEasyUiBasedGenerateTaskPipeline: assert len(files) == 3 def test_process_stream_response_handles_annotation_and_error(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -428,20 +617,36 @@ class TestEasyUiBasedGenerateTaskPipeline: ) events = [ - SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), - SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), - SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), - SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), - SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + _queue_message(QueueAnnotationReplyEvent(message_annotation_id="ann")), + _queue_message(QueueAgentThoughtEvent(agent_thought_id="thought")), + _queue_message(QueueMessageFileEvent(message_file_id="file")), + _queue_message(QueueAgentMessageEvent(chunk=agent_chunk)), + _queue_message(QueueErrorEvent(error=ValueError("boom"))), ] - pipeline.queue_manager.listen = lambda: iter(events) - pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") - pipeline._agent_thought_to_stream_response = lambda event: "thought" - pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" - pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" - pipeline.handle_error = lambda **kwargs: ValueError("boom") - pipeline.error_to_stream_response = lambda err: err + _set_queue_events(pipeline, events) + + def _agent_thought_response(event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse: + return AgentThoughtStreamResponse(task_id="task", id=event.agent_thought_id, position=1) + + def _file_response(event: QueueMessageFileEvent) -> MessageFileStreamResponse: + return MessageFileStreamResponse( + task_id="task", id=event.message_file_id, type="image", belongs_to="user", url="file" + ) + + def _agent_message_response(answer: str, message_id: str) -> AgentMessageStreamResponse: + return AgentMessageStreamResponse(task_id="task", id=message_id, answer=answer) + + _set_method( + pipeline._message_cycle_manager, + "handle_annotation_reply", + lambda event: _AnnotationReply(content="annotated"), + ) + _set_method(pipeline, "_agent_thought_to_stream_response", _agent_thought_response) + _set_method(pipeline._message_cycle_manager, "message_file_to_stream_response", _file_response) + _set_method(pipeline, "_agent_message_to_stream_response", _agent_message_response) + _set_method(pipeline, "handle_error", lambda **kwargs: ValueError("boom")) + _set_method(pipeline, "error_to_stream_response", lambda err: ErrorStreamResponse(task_id="task", err=err)) class _Session: def __init__(self, *args, **kwargs): @@ -462,39 +667,32 @@ class TestEasyUiBasedGenerateTaskPipeline: ) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) responses = list(pipeline._process_stream_response(publisher=None)) - assert "thought" in responses - assert "file" in responses - assert "agent" in responses - assert isinstance(responses[-1], ValueError) + assert any(isinstance(response, AgentThoughtStreamResponse) for response in responses) + assert any(isinstance(response, MessageFileStreamResponse) for response in responses) + agent_response = next(response for response in responses if isinstance(response, AgentMessageStreamResponse)) + assert agent_response.answer == "agent" + assert isinstance(responses[-1], ErrorStreamResponse) + assert isinstance(responses[-1].err, ValueError) assert pipeline._task_state.llm_result.message.content == "annotatedagent" def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - agent_thought = SimpleNamespace( - id="thought", - position=1, - thought="t", - observation="o", - tool="tool", - tool_labels={}, - tool_input="input", - files=[], - ) + agent_thought = _agent_thought() class _Session: def __init__(self, *args, **kwargs): @@ -515,7 +713,7 @@ class TestEasyUiBasedGenerateTaskPipeline: ) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) @@ -524,18 +722,22 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.id == "thought" def test_process_routes_to_stream_and_starts_conversation_name_generation(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) - pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) - pipeline._to_stream_response = lambda generator: "streamed" + _set_method( + pipeline, + "_wrapper_process_stream_response", + lambda trace_manager: iter([PingStreamResponse(task_id="task")]), + ) + _set_method(pipeline, "_to_stream_response", lambda generator: "streamed") result = pipeline.process() @@ -545,18 +747,22 @@ class TestEasyUiBasedGenerateTaskPipeline: ) def test_process_routes_to_blocking_for_completion_mode(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.COMPLETION) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, ) pipeline._message_cycle_manager.generate_conversation_name = Mock() - pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) - pipeline._to_blocking_response = lambda generator: "blocking" + _set_method( + pipeline, + "_wrapper_process_stream_response", + lambda trace_manager: iter([PingStreamResponse(task_id="task")]), + ) + _set_method(pipeline, "_to_blocking_response", lambda generator: "blocking") result = pipeline.process() @@ -564,11 +770,11 @@ class TestEasyUiBasedGenerateTaskPipeline: pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() def test_to_blocking_response_raises_error_stream_exception(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -581,11 +787,11 @@ class TestEasyUiBasedGenerateTaskPipeline: pipeline._to_blocking_response(_gen()) def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -598,11 +804,11 @@ class TestEasyUiBasedGenerateTaskPipeline: pipeline._to_blocking_response(_gen()) def test_to_stream_response_wraps_completion_stream_events(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.COMPLETION) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -617,11 +823,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.message_id == "msg" def test_to_stream_response_wraps_chat_stream_events(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -636,16 +842,19 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.conversation_id == "conv" def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + publisher = cast( + AppGeneratorTTSPublisher, + _AudioPublisher(status="responding", audio="abc"), + ) response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") @@ -653,45 +862,49 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.audio == "abc" def test_listen_audio_msg_returns_none_for_finish_audio(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + publisher = cast( + AppGeneratorTTSPublisher, + _AudioPublisher(status="finish", audio="abc"), + ) assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None def test_wrapper_process_stream_response_without_tts_publisher(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + payload = PingStreamResponse(task_id="task") + _set_method(pipeline, "_process_stream_response", lambda publisher, trace_manager: iter([payload])) responses = list(pipeline._wrapper_process_stream_response()) - assert responses == ["payload"] + assert responses == [payload] def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) entity.app_config.app_model_config_dict = { "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} } pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=entity, - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -703,8 +916,9 @@ class TestEasyUiBasedGenerateTaskPipeline: inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") audio_calls = iter([inline_audio, None]) - pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) - pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + payload = PingStreamResponse(task_id="task") + _set_method(pipeline, "_listen_audio_msg", lambda publisher, task_id: next(audio_calls)) + _set_method(pipeline, "_process_stream_response", lambda publisher, trace_manager: iter([payload])) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", lambda tenant_id, voice, language: _Publisher(), @@ -713,19 +927,19 @@ class TestEasyUiBasedGenerateTaskPipeline: responses = list(pipeline._wrapper_process_stream_response()) assert responses[0] == inline_audio - assert responses[1] == "payload" + assert responses[1] == payload assert isinstance(responses[-1], MessageAudioEndStreamResponse) def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) entity.app_config.app_model_config_dict = { "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} } pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=entity, - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -744,7 +958,7 @@ class TestEasyUiBasedGenerateTaskPipeline: clock["value"] += 0.1 return clock["value"] - pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + _set_method(pipeline, "_process_stream_response", lambda publisher, trace_manager: iter([])) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", lambda tenant_id, voice, language: _Publisher(), @@ -758,24 +972,30 @@ class TestEasyUiBasedGenerateTaskPipeline: assert isinstance(responses[-1], MessageAudioEndStreamResponse) def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) pipeline._task_state.llm_result.message.content = "raw answer" - pipeline.queue_manager.listen = lambda: iter( - [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] - ) + _set_queue_events(pipeline, [_queue_message(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))]) pipeline._handle_stop = Mock() - pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" - pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" - pipeline._save_message = lambda **kwargs: None - pipeline._message_end_to_stream_response = lambda: "end" + _set_method(pipeline, "handle_output_moderation_when_task_finished", lambda answer: "moderated answer") + _set_method( + pipeline._message_cycle_manager, + "message_replace_to_stream_response", + lambda answer: MessageReplaceStreamResponse(task_id="task", answer=answer, reason=""), + ) + _set_method(pipeline, "_save_message", lambda **kwargs: None) + _set_method( + pipeline, + "_message_end_to_stream_response", + lambda: MessageEndStreamResponse(task_id="task", id="msg"), + ) class _Session: def __init__(self, *args, **kwargs): @@ -793,20 +1013,22 @@ class TestEasyUiBasedGenerateTaskPipeline: monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) responses = list(pipeline._process_stream_response(publisher=None)) - assert responses == ["replace:moderated answer", "end"] + assert isinstance(responses[0], MessageReplaceStreamResponse) + assert responses[0].answer == "moderated answer" + assert isinstance(responses[1], MessageEndStreamResponse) pipeline._handle_stop.assert_called_once() def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -823,12 +1045,13 @@ class TestEasyUiBasedGenerateTaskPipeline: handled["retriever"] += 1 pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources - pipeline.queue_manager.listen = lambda: iter( + _set_queue_events( + pipeline, [ - SimpleNamespace(event=retriever_event), - SimpleNamespace(event=SimpleNamespace()), - SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), - ] + _queue_message(retriever_event), + _unknown_queue_message(), + _queue_message(QueueLLMChunkEvent(chunk=chunk)), + ], ) responses = list(pipeline._process_stream_response(publisher=None)) @@ -837,11 +1060,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert handled["retriever"] == 1 def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -852,76 +1075,103 @@ class TestEasyUiBasedGenerateTaskPipeline: delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), ) pipeline._handle_output_moderation_chunk = lambda text: True - pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + _set_queue_events(pipeline, [_queue_message(QueueLLMChunkEvent(chunk=chunk))]) responses = list(pipeline._process_stream_response(publisher=None)) assert responses == [] def test_process_stream_response_ignores_unsupported_chunk_content_types(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - chunk = SimpleNamespace( + chunk = LLMResultChunk.model_construct( prompt_messages=[], - delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), - ) - pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None - pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] - pipeline.queue_manager.listen = lambda: iter( - [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + delta=LLMResultChunkDelta.model_construct( + message=AssistantPromptMessage.model_construct(content=[object(), "ok"]) + ), ) + _set_method(pipeline._message_cycle_manager, "get_message_event_type", lambda message_id: StreamEvent.MESSAGE) + _set_queue_events(pipeline, [_queue_message(QueueLLMChunkEvent.model_construct(chunk=chunk))]) responses = list(pipeline._process_stream_response(publisher=None)) - assert responses == ["ok"] + assert len(responses) == 1 + assert isinstance(responses[0], MessageStreamResponse) + assert responses[0].answer == "ok" + assert pipeline._task_state.llm_result.message.content == "ok" - def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + def test_process_stream_response_skips_none_chunk_content(self): + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, ) - pipeline._conversation_name_generate_thread = object() - pipeline.queue_manager.listen = lambda: iter([]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + pipeline._message_cycle_manager.message_to_stream_response = Mock() + _set_queue_events(pipeline, [_queue_message(QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + pipeline._message_cycle_manager.message_to_stream_response.assert_not_called() + assert pipeline._task_state.llm_result.message.content == "" + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=_FakeQueueManager(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = Thread() + _set_queue_events(pipeline, []) assert list(pipeline._process_stream_response(publisher=None)) == [] def test_save_message_persists_fields_and_emits_trace(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() application_generate_entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) application_generate_entity.extras = {"trace_session_id": "session-1"} pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, ) pipeline.start_at = 10.0 - pipeline._model_config = SimpleNamespace(mode="chat") + _set_method(pipeline, "_model_config", _ModelConfigMode(mode="chat")) pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} ) - message_obj = SimpleNamespace(id="msg") - conversation_obj = SimpleNamespace(id="conv") + message_obj = _make_message() + conversation_obj = _make_conversation(AppMode.CHAT) session = Mock() session.scalar.side_effect = [message_obj, conversation_obj] - trace_manager = SimpleNamespace(add_trace_task=Mock()) + trace_manager_double = _TraceManagerDouble() + trace_manager = cast(TraceQueueManager, trace_manager_double) sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] monkeypatch.setattr( @@ -949,8 +1199,8 @@ class TestEasyUiBasedGenerateTaskPipeline: assert message_obj.message == "serialized-prompt" assert message_obj.answer == "hello" assert message_obj.provider_response_latency == 5.0 - trace_manager.add_trace_task.assert_called_once() - trace_task = trace_manager.add_trace_task.call_args.args[0] + trace_manager_double.add_trace_task.assert_called_once() + trace_task = trace_manager_double.add_trace_task.call_args.args[0] assert trace_task.trace_type == TraceTaskName.MESSAGE_TRACE assert trace_task.conversation_id == "conv" assert trace_task.message_id == "msg" @@ -958,11 +1208,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert len(sent_payloads) == 1 def test_save_message_raises_when_message_not_found(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -974,27 +1224,27 @@ class TestEasyUiBasedGenerateTaskPipeline: pipeline._save_message(session=session) def test_save_message_raises_when_conversation_not_found(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, ) session = Mock() - session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + session.scalar.side_effect = [_make_message(), None] with pytest.raises(ValueError, match="Conversation conv not found"): pipeline._save_message(session=session) def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -1021,20 +1271,21 @@ class TestEasyUiBasedGenerateTaskPipeline: monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) response = pipeline._message_end_to_stream_response() assert response.id == "msg" - assert response.metadata["usage"]["prompt_tokens"] == 1 + usage_metadata = cast(dict[str, object], response.metadata["usage"]) + assert usage_metadata["prompt_tokens"] == 1 def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, @@ -1060,7 +1311,7 @@ class TestEasyUiBasedGenerateTaskPipeline: monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) response = pipeline._message_end_to_stream_response() @@ -1068,39 +1319,36 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.files is None def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=False, ) message_files = [ - SimpleNamespace( - id="mf-local-fallback", - message_id="msg", + _message_file( + file_id="mf-local-fallback", transfer_method=FileTransferMethod.LOCAL_FILE, url="", upload_file_id="upload-missing", - type="file", + file_type=FileType.CUSTOM, ), - SimpleNamespace( - id="mf-tool-http", - message_id="msg", + _message_file( + file_id="mf-tool-http", transfer_method=FileTransferMethod.TOOL_FILE, url="http://cdn.example.com/file.txt?x=1", upload_file_id=None, - type="file", + file_type=FileType.CUSTOM, ), - SimpleNamespace( - id="mf-tool-noext", - message_id="msg", + _message_file( + file_id="mf-tool-noext", transfer_method=FileTransferMethod.TOOL_FILE, url="tool/path/toolid", upload_file_id=None, - type="file", + file_type=FileType.CUSTOM, ), ] @@ -1128,7 +1376,7 @@ class TestEasyUiBasedGenerateTaskPipeline: monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) monkeypatch.setattr( "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", @@ -1148,11 +1396,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert files[2]["extension"] == ".bin" def test_agent_message_to_stream_response_builds_payload(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -1164,11 +1412,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response.answer == "hello" def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch: pytest.MonkeyPatch): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -1190,7 +1438,7 @@ class TestEasyUiBasedGenerateTaskPipeline: monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr( "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", - SimpleNamespace(engine=object()), + _FakeDb(), ) response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) @@ -1198,11 +1446,11 @@ class TestEasyUiBasedGenerateTaskPipeline: assert response is None def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): - conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) - message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + conversation = _make_conversation(AppMode.CHAT) + message = _make_message() pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), - queue_manager=SimpleNamespace(), + queue_manager=_FakeQueueManager(), conversation=conversation, message=message, stream=True, @@ -1216,7 +1464,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def append_new_token(self, text): appended_tokens.append(text) - pipeline.output_moderation_handler = _Moderation() + _set_method(pipeline, "output_moderation_handler", _Moderation()) result = pipeline._handle_output_moderation_chunk("next-token") diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 40885cfed2e..45d6fc1cd07 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -4,9 +4,10 @@ import io import os import tempfile from collections import UserDict +from collections.abc import Generator from pathlib import Path from types import SimpleNamespace -from typing import override +from typing import Protocol, cast, override from unittest.mock import MagicMock import pytest @@ -18,6 +19,14 @@ import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor +class _TextOxmlElement(Protocol): + text: str | None + + +def _set_oxml_text(element: object, text: str) -> None: + cast(_TextOxmlElement, element).text = text + + def _generate_table_with_merged_cells(): doc = Document() @@ -190,8 +199,8 @@ def test_extract_images_from_docx_uses_internal_files_url(): from configs import dify_config # Mock the configuration values - original_files_url = getattr(dify_config, "FILES_URL", None) - original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None) + original_files_url = dify_config.FILES_URL + original_internal_files_url = dify_config.INTERNAL_FILES_URL try: # Set both URLs - INTERNAL should take precedence @@ -233,7 +242,7 @@ def test_extract_hyperlinks(monkeypatch: pytest.MonkeyPatch): new_run = OxmlElement("w:r") t = OxmlElement("w:t") - t.text = "Dify" + _set_oxml_text(t, "Dify") new_run.append(t) hyperlink.append(new_run) p._p.append(hyperlink) @@ -286,7 +295,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch): run2 = OxmlElement("w:r") instrText = OxmlElement("w:instrText") - instrText.text = ' HYPERLINK "http://example.com" ' + _set_oxml_text(instrText, ' HYPERLINK "http://example.com" ') run2.append(instrText) p._p.append(run2) @@ -298,7 +307,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch): run4 = OxmlElement("w:r") t4 = OxmlElement("w:t") - t4.text = "Example" + _set_oxml_text(t4, "Example") run4.append(t4) p._p.append(run4) @@ -380,20 +389,27 @@ def test_close_is_idempotent(): extractor.temp_file.close.assert_called_once() -async def _async_close() -> None: - return None - - def test_close_closes_awaitable_close_result(): + class FakeAwaitable: + closed: bool = False + + def __await__(self) -> Generator[None, None, None]: + if False: + yield None + return None + + def close(self) -> None: + self.closed = True + extractor = object.__new__(WordExtractor) extractor._closed = False extractor.temp_file = MagicMock() - close_result = _async_close() + close_result = FakeAwaitable() extractor.temp_file.close = MagicMock(return_value=close_result) extractor.close() - assert close_result.cr_frame is None + assert close_result.closed is True extractor.temp_file.close.assert_called_once() @@ -506,6 +522,32 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch: pytest.MonkeyPatch): assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain" +def test_parse_docx_reads_real_paragraph_table_order(monkeypatch: pytest.MonkeyPatch): + doc = Document() + doc.add_paragraph("Before table") + table = doc.add_table(rows=2, cols=2) + table.cell(0, 0).text = "Header A" + table.cell(0, 1).text = "Header B" + table.cell(1, 0).text = "Cell A" + table.cell(1, 1).text = "Cell B" + doc.add_paragraph("After table") + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + extractor = object.__new__(WordExtractor) + monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: {}) + + try: + assert extractor.parse_docx(tmp_path) == ( + "Before table\n| Header A | Header B |\n| --- | --- |\n| Cell A | Cell B |\nAfter table" + ) + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch): extractor = object.__new__(WordExtractor) @@ -620,8 +662,15 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke self.element = element self.text = getattr(element, "text", "") - paragraph_main = SimpleNamespace( - _element=[ + class FakeParagraph: + def __init__(self, children): + self._element = children + + class FakeTable: + rows: list[object] = [] + + paragraph_main = FakeParagraph( + [ FakeChild( qn("w:r"), text="run-text", @@ -646,17 +695,16 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke ), ] ) - paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")]) + paragraph_empty = FakeParagraph([FakeChild(qn("w:r"), text=" ")]) + table = FakeTable() fake_doc = SimpleNamespace( part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}), - paragraphs=[paragraph_main, paragraph_empty], - tables=[SimpleNamespace(rows=[])], - element=SimpleNamespace( - body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")] - ), + iter_inner_content=lambda: iter([paragraph_main, paragraph_empty, table]), ) + monkeypatch.setattr(we, "Paragraph", FakeParagraph) + monkeypatch.setattr(we, "Table", FakeTable) monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc) monkeypatch.setattr(we, "Run", FakeRun) monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) @@ -688,7 +736,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): run_elem = OxmlElement("w:r") t = OxmlElement("w:t") - t.text = "Dify" + _set_oxml_text(t, "Dify") run_elem.append(t) hyperlink.append(run_elem) p._p.append(hyperlink) @@ -728,7 +776,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto(): run_elem = OxmlElement("w:r") t = OxmlElement("w:t") - t.text = "john@test.com" + _set_oxml_text(t, "john@test.com") run_elem.append(t) hyperlink.append(run_elem) p._p.append(hyperlink) diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 4ba4d54fa09..4368e9cddc0 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -234,20 +234,6 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"]) - def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9) - rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1) - - with patch( - "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" - ) as mock_retrieve: - mock_retrieve.return_value = [accepted, rejected] - reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) - - assert len(docs) == 1 - assert docs[0].metadata["score"] == 0.9 - def test_index_list_chunks_high_quality( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 8ef0e046ef6..7d339a7701f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail -from core.rag.entities import ParentMode +from core.rag.entities import ParentMode, Rule, Segmentation from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document @@ -293,29 +293,14 @@ class TestParentChildIndexProcessor: mock_summary.assert_called_once_with(dataset, None) - def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8) - low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2) - - with patch( - "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" - ) as mock_retrieve: - mock_retrieve.return_value = [ok_result, low_result] - reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} - docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model) - - assert len(docs) == 1 - assert docs[0].page_content == "keep" - assert docs[0].metadata["score"] == 0.8 - def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None: - rules = SimpleNamespace(subchunk_segmentation=None) + rules = Rule(subchunk_segmentation=None) with pytest.raises(ValueError, match="No subchunk segmentation found"): processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None) def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None: - rules = SimpleNamespace(subchunk_segmentation=self._segmentation()) + rules = Rule(subchunk_segmentation=Segmentation(max_tokens=200, chunk_overlap=10, separator="\n")) splitter = Mock() splitter.split_documents.return_value = [ Document(page_content=".child-1", metadata={}), diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 1f74ccb4387..30600e64651 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -258,19 +258,6 @@ class TestQAIndexProcessor: mock_summary.assert_called_once_with(dataset, None) vector.delete.assert_called_once() - def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None: - result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9) - result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1) - - with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: - mock_retrieve.return_value = [result_ok, result_low] - reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) - - assert len(docs) == 1 - assert docs[0].page_content == "accepted" - assert docs[0].metadata["score"] == 0.9 - def test_index_adds_documents_and_vectors_for_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: @@ -331,7 +318,7 @@ class TestQAIndexProcessor: def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None: preview_items = [PreviewDetail(content="Q1")] - assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items + assert processor.generate_summary_preview("tenant-1", preview_items, {"enable": False}) is preview_items def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None: all_qa_documents: list[Document] = [] diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py index 21118cc688d..fe1109db301 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -9,7 +9,6 @@ from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from core.rag.retrieval.retrieval_methods import RetrievalMethod class _ForwardingBaseIndexProcessor(BaseIndexProcessor): @@ -52,17 +51,6 @@ class _ForwardingBaseIndexProcessor(BaseIndexProcessor): def format_preview(self, chunks): return super().format_preview(chunks) - @override - def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): - return super().retrieve( - retrieval_method=retrieval_method, - query=query, - dataset=dataset, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - ) - class TestBaseIndexProcessor: @pytest.fixture @@ -75,7 +63,7 @@ class TestBaseIndexProcessor: with pytest.raises(NotImplementedError): processor.transform([]) with pytest.raises(NotImplementedError): - processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {}) + processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {"enable": False}) with pytest.raises(NotImplementedError): processor.load(Mock(), []) with pytest.raises(NotImplementedError): @@ -84,8 +72,6 @@ class TestBaseIndexProcessor: processor.index(Mock(), Mock(), {}) with pytest.raises(NotImplementedError): processor.format_preview([]) - with pytest.raises(NotImplementedError): - processor.retrieve(RetrievalMethod.SEMANTIC_SEARCH, "q", Mock(), 3, 0.5, {}) def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: with patch( diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 4fbab8e56a1..46cf0f7ac49 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,7 +1,8 @@ import threading +from collections.abc import Generator from contextlib import contextmanager, nullcontext from types import SimpleNamespace -from typing import Any +from typing import Any, cast from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -86,6 +87,14 @@ def create_mock_document( ) +def _dataset(**values: object) -> Dataset: + return cast(Dataset, SimpleNamespace(**values)) + + +def _metadata_condition() -> AppMetadataFilteringCondition: + return AppMetadataFilteringCondition(logical_operator="and", conditions=[]) + + def create_side_effect_for_search(documents: list[Document]): """ Create a side effect function for mocking search methods. @@ -2101,6 +2110,7 @@ class TestDocumentModel: doc = Document(page_content="Test content", vector=vector) assert doc.vector == vector + assert doc.vector is not None assert len(doc.vector) == 5 def test_document_with_external_provider(self): @@ -2914,14 +2924,14 @@ class TestProcessMetadataFilterFunc: return mock_string_access elif name in ["year", "price", "rating"]: return mock_float_access + elif name == "description": + return mock_null_access else: return mock_string_access mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) mock_metadata_field.as_string.return_value = mock_string_access mock_metadata_field.as_float.return_value = mock_float_access - mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ - mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot return mock_metadata_field @@ -3933,11 +3943,19 @@ class TestDatasetRetrievalAdditionalHelpers: usage=None, ), ) - text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2])) + + def _chunks() -> Generator[Any]: + yield chunk_1 + yield chunk_2 + + text, returned_usage = retrieval._handle_invoke_result(_chunks()) assert text == "hello world" assert returned_usage == usage - text_empty, usage_empty = retrieval._handle_invoke_result(iter([])) + def _empty_chunks() -> Generator[Any]: + yield from () + + text_empty, usage_empty = retrieval._handle_invoke_result(_empty_chunks()) assert text_empty == "" assert usage_empty == LLMUsage.empty_usage() @@ -4176,7 +4194,9 @@ class TestDatasetRetrievalAdditionalHelpers: ) assert mapping == {"d1": ["doc-1"]} assert condition is not None - assert condition.conditions[0].value == "Alice" + assert condition.conditions + first_condition = condition.conditions[0] + assert first_condition.value == "Alice" with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): with pytest.raises(ValueError, match="Invalid metadata filtering mode"): @@ -4666,7 +4686,7 @@ class TestSingleAndMultipleRetrieveCoverage: return DatasetRetrieval() def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None: - dataset = SimpleNamespace( + dataset = _dataset( id="ds-1", name="External DS", description=None, @@ -4711,7 +4731,7 @@ class TestSingleAndMultipleRetrieveCoverage: assert retrieval.llm_usage.total_tokens == 2 def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None: - dataset = SimpleNamespace( + dataset = _dataset( id="ds-1", name="Internal DS", description="dataset desc", @@ -4755,7 +4775,7 @@ class TestSingleAndMultipleRetrieveCoverage: model_config=Mock(), planning_strategy=PlanningStrategy.ROUTER, metadata_filter_document_ids={"ds-1": ["doc-1"]}, - metadata_condition=SimpleNamespace(), + metadata_condition=_metadata_condition(), ) assert results == [result_doc] @@ -4772,7 +4792,7 @@ class TestSingleAndMultipleRetrieveCoverage: user_from="workflow", query="python", available_datasets=[ - SimpleNamespace(id="ds-1", name="DS", description=None), + _dataset(id="ds-1", name="DS", description=None), ], model_instance=Mock(), model_config=Mock(), @@ -4781,7 +4801,7 @@ class TestSingleAndMultipleRetrieveCoverage: assert results == [] def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None: - dataset = SimpleNamespace( + dataset = _dataset( id="ds-1", name="Internal DS", description="desc", @@ -4806,7 +4826,7 @@ class TestSingleAndMultipleRetrieveCoverage: model_config=Mock(), planning_strategy=PlanningStrategy.REACT_ROUTER, metadata_filter_document_ids=None, - metadata_condition=SimpleNamespace(), + metadata_condition=_metadata_condition(), ) missing_doc_ids = retrieval.single_retrieve( app_id="app-1", @@ -4841,8 +4861,8 @@ class TestSingleAndMultipleRetrieveCoverage: ) mixed = [ - SimpleNamespace(id="d1", indexing_technique="high_quality"), - SimpleNamespace(id="d2", indexing_technique="economy"), + _dataset(id="d1", indexing_technique="high_quality"), + _dataset(id="d2", indexing_technique="economy"), ] with pytest.raises(ValueError, match="different indexing technique"): retrieval.multiple_retrieve( @@ -4859,13 +4879,13 @@ class TestSingleAndMultipleRetrieveCoverage: ) high_quality_mismatch = [ - SimpleNamespace( + _dataset( id="d1", indexing_technique="high_quality", embedding_model="model-a", embedding_model_provider="provider-a", ), - SimpleNamespace( + _dataset( id="d2", indexing_technique="high_quality", embedding_model="model-b", @@ -4888,13 +4908,13 @@ class TestSingleAndMultipleRetrieveCoverage: def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None: datasets = [ - SimpleNamespace( + _dataset( id="d1", indexing_technique="high_quality", embedding_model="model-a", embedding_model_provider="provider-a", ), - SimpleNamespace( + _dataset( id="d2", indexing_technique="high_quality", embedding_model="model-a", @@ -4956,7 +4976,7 @@ class TestSingleAndMultipleRetrieveCoverage: def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None: datasets = [ - SimpleNamespace( + _dataset( id="d1", indexing_technique="high_quality", embedding_model="model-a",