chore(api): Fix several typing errors (#37237)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
chariri 2026-06-12 10:36:48 +09:00 committed by GitHub
parent 99351d2f98
commit b61d39ae2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 777 additions and 538 deletions

View File

@ -24,14 +24,19 @@ from models.model import Message
logger = logging.getLogger(__name__)
class BasedGenerateTaskPipeline:
class BasedGenerateTaskPipeline[AppGenerateEntityT: AppGenerateEntity]:
"""
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
The type parameter preserves the concrete application generate entity for
subclasses after the shared initializer stores it on ``_application_generate_entity``.
"""
_application_generate_entity: AppGenerateEntityT
def __init__(
self,
application_generate_entity: AppGenerateEntity,
application_generate_entity: AppGenerateEntityT,
queue_manager: AppQueueManager,
stream: bool,
):

View File

@ -65,19 +65,20 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought, Me
logger = logging.getLogger(__name__)
type EasyUIAppGenerateEntity = ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline[EasyUIAppGenerateEntity]):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: EasyUITaskState
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
application_generate_entity: EasyUIAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@ -310,12 +311,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
yield response
case QueueLLMChunkEvent() | QueueAgentMessageEvent():
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
delta_content = chunk.delta.message.content
if delta_content is None:
continue
if isinstance(chunk.delta.message.content, list):
if isinstance(delta_content, list):
# EasyUI streams text only; structured multimodal chunks contribute their text parts.
delta_text = ""
for content in chunk.delta.message.content:
for content in delta_content:
logger.debug(
"The content type %s in LLM chunk delta message content.: %r", type(content), content
)
@ -331,17 +333,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
content,
)
continue
else:
delta_text = delta_content
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
current_content += delta_text
self._task_state.llm_result.message.content = current_content
match event:
@ -352,13 +356,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
answer=delta_text,
message_id=self._message_id,
event_type=self._precomputed_event_type,
)
case _:
yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
answer=delta_text,
message_id=self._message_id,
)
case QueueMessageReplaceEvent():
@ -389,9 +393,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not conversation:
raise ValueError(f"Conversation {self._conversation_id} not found")
message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
saved_prompt = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
object.__setattr__(message, "message", saved_prompt)
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit

View File

@ -107,7 +107,7 @@ class LLMGenerator:
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
prompts = [UserPromptMessage(content=prompt)]
prompts: list[PromptMessage] = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response: LLMResult = model_instance.invoke_llm(
@ -201,11 +201,13 @@ class LLMGenerator:
except InvokeAuthorizationError:
return []
prompt_messages = [UserPromptMessage(content=prompt)]
prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try:
model_parameters: dict[str, object]
stop: list[str]
configured_completion_params = configured_model.get("completion_params")
if use_configured_model and isinstance(configured_completion_params, dict):
model_parameters, stop = _normalize_completion_params(configured_completion_params)
@ -253,7 +255,7 @@ class LLMGenerator:
remove_template_variables=False,
)
prompt_messages = [UserPromptMessage(content=prompt_generate)]
no_variable_prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate)]
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
@ -266,7 +268,7 @@ class LLMGenerator:
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
prompt_messages=list(no_variable_prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = response.message.get_text_content()
@ -299,7 +301,7 @@ class LLMGenerator:
},
remove_template_variables=False,
)
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
prompt_generate_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate_prompt)]
# get model instance
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
@ -314,7 +316,7 @@ class LLMGenerator:
try:
# the first step to generate the task prompt
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_generate_messages), model_parameters=model_parameters, stream=False
)
except InvokeError as e:
error = str(e)
@ -331,7 +333,7 @@ class LLMGenerator:
},
remove_template_variables=False,
)
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
parameter_messages: list[PromptMessage] = [UserPromptMessage(content=parameter_generate_prompt)]
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
@ -341,7 +343,7 @@ class LLMGenerator:
},
remove_template_variables=False,
)
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
statement_messages: list[PromptMessage] = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content: LLMResult = model_instance.invoke_llm(
@ -397,7 +399,7 @@ class LLMGenerator:
model=args.model_config_data.name,
)
prompt_messages = [UserPromptMessage(content=prompt)]
prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)]
model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
@ -455,7 +457,7 @@ class LLMGenerator:
model=args.model_config_data.name,
)
prompt_messages = [
prompt_messages: list[PromptMessage] = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=args.instruction),
]
@ -634,7 +636,7 @@ class LLMGenerator:
system_prompt = LLM_MODIFY_CODE_SYSTEM
case _:
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
prompt_messages = [
prompt_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
UserPromptMessage(
content=json.dumps(

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, cast
from typing import NotRequired, TypedDict, cast
from core.prompt.simple_prompt_transform import ModelMode
from graphon.model_runtime.entities import (
@ -13,19 +13,46 @@ from graphon.model_runtime.entities import (
)
class SavedPromptFile(TypedDict):
type: str
data: str
detail: NotRequired[str]
format: NotRequired[str]
class SavedPromptToolCallFunction(TypedDict):
name: str
arguments: str
class SavedPromptToolCall(TypedDict):
id: str
type: str
function: SavedPromptToolCallFunction
class SavedPrompt(TypedDict):
role: str
text: str
files: NotRequired[list[SavedPromptFile]]
tool_calls: NotRequired[list[SavedPromptToolCall]]
class PromptMessageUtil:
@staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]):
def prompt_messages_to_prompt_for_saving(
model_mode: str, prompt_messages: Sequence[PromptMessage]
) -> list[SavedPrompt]:
"""
Prompt messages to prompt for saving.
:param model_mode: model mode
:param prompt_messages: prompt messages
:return:
"""
prompts = []
prompts: list[SavedPrompt] = []
if model_mode == ModelMode.CHAT:
tool_calls = []
for prompt_message in prompt_messages:
tool_calls: list[SavedPromptToolCall] = []
if prompt_message.role == PromptMessageRole.USER:
role = "user"
elif prompt_message.role == PromptMessageRole.ASSISTANT:
@ -50,7 +77,7 @@ class PromptMessageUtil:
continue
text = ""
files = []
files: list[SavedPromptFile] = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
match content:
@ -77,7 +104,7 @@ class PromptMessageUtil:
else:
text = cast(str, prompt_message.content)
prompt = {"role": role, "text": text, "files": files}
prompt: SavedPrompt = {"role": role, "text": text, "files": files}
if tool_calls:
prompt["tool_calls"] = tool_calls
@ -86,14 +113,14 @@ class PromptMessageUtil:
else:
prompt_message = prompt_messages[0]
text = ""
files = []
prompt_files: list[SavedPromptFile] = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append(
prompt_files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
@ -103,13 +130,13 @@ class PromptMessageUtil:
else:
text = cast(str, prompt_message.content)
params: dict[str, Any] = {
params: SavedPrompt = {
"role": "user",
"text": text,
}
if files:
params["files"] = files
if prompt_files:
params["files"] = prompt_files
prompts.append(params)

View File

@ -105,6 +105,8 @@ class WaterCrawlProvider:
def scrape_url(self, url: str) -> WatercrawlDocumentData:
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
if not isinstance(response, dict):
raise ValueError("Invalid scrape response. Expected a JSON dictionary.")
return self._structure_data(response)
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:

View File

@ -15,6 +15,8 @@ from urllib.parse import urlparse
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.table import Table
from docx.text.paragraph import Paragraph
from docx.text.run import Run
from configs import dify_config
@ -286,10 +288,10 @@ class WordExtractor(BaseExtractor):
return "".join(paragraph_content).strip()
def parse_docx(self, docx_path):
def parse_docx(self, docx_path: str) -> str:
doc = DocxDocument(docx_path)
content = []
content: list[str] = []
image_map = self._extract_images_from_docx(doc)
@ -445,18 +447,11 @@ class WordExtractor(BaseExtractor):
process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()
tables = doc.tables.copy()
for element in doc.element.body:
if hasattr(element, "tag"):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para)
if parsed_paragraph.strip():
content.append(parsed_paragraph)
else:
content.append("\n")
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map))
for block in doc.iter_inner_content():
match block:
case Paragraph():
parsed_paragraph = parse_paragraph(block)
content.append(parsed_paragraph if parsed_paragraph.strip() else "\n")
case Table():
content.append(self._table_to_markdown(block, image_map))
return "\n".join(content)

View File

@ -16,11 +16,9 @@ from sqlalchemy import select
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.file import remote_fetcher
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
@ -99,18 +97,6 @@ class BaseIndexProcessor(ABC):
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
raise NotImplementedError
@abstractmethod
def retrieve(
self,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: RerankingModelDict,
) -> list[Document]:
raise NotImplementedError
def _get_splitter(
self,
processing_rule_mode: str,

View File

@ -16,9 +16,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
@ -28,7 +26,6 @@ from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file_reference import build_file_reference
from extensions.ext_database import db
@ -182,35 +179,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.delete()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
@override
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
documents: list[Any] = []

View File

@ -12,8 +12,6 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import ParentMode, Rule
@ -23,7 +21,6 @@ from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
@ -223,35 +220,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
)
db.session.commit()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _split_child_nodes(
self,
document_node: Document,

View File

@ -15,8 +15,6 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
@ -25,7 +23,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
@ -187,35 +184,6 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
vector.delete()
@override
def retrieve(
self,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: RerankingModelDict,
):
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
@override
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
qa_chunks = QAStructureChunk.model_validate(chunks)

View File

@ -609,7 +609,7 @@ class DatasetRetrieval:
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
):
tools = []
tools: list[PromptMessageTool] = []
for dataset in available_datasets:
description = dataset.description
if not description:
@ -1162,7 +1162,7 @@ class DatasetRetrieval:
:param invoke_from: invoke from
:param hit_callback: hit callback
"""
tools = []
tools: list[DatasetRetrieverBaseTool] = []
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id

View File

@ -30,6 +30,11 @@ class CelerySSLOptionsDict(TypedDict):
ssl_keyfile: str | None
class CeleryBeatScheduleEntry(TypedDict):
task: str
schedule: crontab | timedelta
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Only apply SSL if we're using Redis as broker/backend
@ -152,7 +157,7 @@ def init_app(app: DifyApp) -> Celery:
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
beat_schedule = {}
beat_schedule: dict[str, CeleryBeatScheduleEntry] = {}
if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
imports.append("schedule.clean_embedding_cache_task")
beat_schedule["clean_embedding_cache_task"] = {

View File

@ -4,7 +4,7 @@ from datetime import datetime, timedelta
from typing import Any, cast, override
import mlflow
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
from mlflow.entities import Document, LiveSpan, Span, SpanEvent, SpanStatusCode, SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
@ -31,6 +31,8 @@ from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
type SpanAttributes = dict[str, object]
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
"""Convert datetime to nanosecond timestamp for MLflow API"""
@ -39,6 +41,32 @@ def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1_000_000_000)
def _start_span_no_context(
*,
name: str,
span_type: str,
parent_span: LiveSpan | None = None,
inputs: object | None = None,
attributes: SpanAttributes | None = None,
start_time_ns: int | None = None,
) -> LiveSpan:
"""Start an MLflow span while preserving structured Dify attributes.
MLflow 3.11 annotates `start_span_no_context(..., attributes=...)` as `dict[str, str]`,
but the implementation immediately calls `LiveSpan.set_attributes(dict[str, Any])`.
`LiveSpan` JSON-serializes arbitrary values before storing them in OpenTelemetry, and
reserved attributes like `mlflow.chat.tokenUsage` are expected to round-trip as dicts.
"""
return start_span_no_context(
name=name,
span_type=span_type,
parent_span=parent_span,
inputs=inputs,
attributes=cast(dict[str, str] | None, attributes),
start_time_ns=start_time_ns,
)
class MLflowDataTrace(BaseTraceInstance):
def __init__(self, config: MLflowConfig | DatabricksConfig):
super().__init__(config)
@ -119,7 +147,7 @@ class MLflowDataTrace(BaseTraceInstance):
if trace_info.query:
workflow_inputs["query"] = trace_info.query
workflow_span = start_span_no_context(
workflow_span = _start_span_no_context(
name=TraceTaskName.WORKFLOW_TRACE.value,
span_type=SpanType.CHAIN,
inputs=workflow_inputs,
@ -139,7 +167,7 @@ class MLflowDataTrace(BaseTraceInstance):
# Create child spans for workflow nodes
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
inputs = None
attributes = {
attributes: SpanAttributes = {
"node_id": node.id,
"node_type": node.node_type,
"status": node.status,
@ -157,7 +185,7 @@ class MLflowDataTrace(BaseTraceInstance):
if not inputs:
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
node_span = _start_span_no_context(
name=node.title,
span_type=self._get_node_span_type(node.node_type),
parent_span=workflow_span,
@ -212,7 +240,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[object, SpanAttributes]:
"""Parse LLM inputs and attributes from LLM workflow node"""
if node.process_data is None:
return {}, {}
@ -266,16 +294,16 @@ class MLflowDataTrace(BaseTraceInstance):
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
file_list.append(f"{base_url}/{message_file_data.url}")
span = start_span_no_context(
span = _start_span_no_context(
name=TraceTaskName.MESSAGE_TRACE.value,
span_type=SpanType.LLM,
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
inputs=self._parse_prompts(trace_info.inputs),
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"message_id": trace_info.message_id,
"model_provider": trace_info.message_data.model_provider,
"model_id": trace_info.message_data.model_id,
"conversation_mode": trace_info.conversation_mode,
"file_list": file_list, # type: ignore[dict-item]
"file_list": file_list,
"total_price": trace_info.message_data.total_price,
**trace_info.metadata,
},
@ -330,15 +358,15 @@ class MLflowDataTrace(BaseTraceInstance):
return metadata.get("from_account_id") # type: ignore[return-value]
def tool_trace(self, trace_info: ToolTraceInfo):
span = start_span_no_context(
span = _start_span_no_context(
name=trace_info.tool_name,
span_type=SpanType.TOOL,
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
inputs=trace_info.tool_inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
"message_id": trace_info.message_id,
"metadata": trace_info.metadata,
"tool_config": trace_info.tool_config,
"tool_parameters": trace_info.tool_parameters,
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
@ -367,13 +395,13 @@ class MLflowDataTrace(BaseTraceInstance):
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span = start_span_no_context(
span = _start_span_no_context(
name=TraceTaskName.MODERATION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs or {},
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
"message_id": trace_info.message_id,
"metadata": trace_info.metadata,
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
@ -391,13 +419,13 @@ class MLflowDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
span = start_span_no_context(
span = _start_span_no_context(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
span_type=SpanType.RETRIEVER,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
"message_id": trace_info.message_id,
"metadata": trace_info.metadata,
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
@ -410,15 +438,15 @@ class MLflowDataTrace(BaseTraceInstance):
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
span = start_span_no_context(
span = _start_span_no_context(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
"model_id": trace_info.model_id, # type: ignore[dict-item]
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
"message_id": trace_info.message_id,
"model_provider": trace_info.model_provider,
"model_id": trace_info.model_id,
"total_tokens": trace_info.total_tokens or 0,
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
@ -439,11 +467,11 @@ class MLflowDataTrace(BaseTraceInstance):
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
span = start_span_no_context(
span = _start_span_no_context(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
span_type=SpanType.CHAIN,
inputs=trace_info.inputs,
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
attributes={"message_id": trace_info.message_id},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
import pytest
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
@ -361,6 +362,7 @@ class TestWorkflowTrace:
assert inputs["query"] == "hello"
def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db):
usage = {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
llm_node = _make_node(
node_type=BuiltinNodeTypes.LLM,
process_data=json.dumps(
@ -369,7 +371,7 @@ class TestWorkflowTrace:
"model_name": "gpt-4",
"model_provider": "openai",
"finish_reason": "stop",
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
"usage": usage,
}
),
outputs='{"text": "hello world"}',
@ -383,6 +385,14 @@ class TestWorkflowTrace:
trace_instance.workflow_trace(_make_workflow_trace_info())
assert mock_tracing["start"].call_count == 2
node_start_call = mock_tracing["start"].call_args_list[1]
attrs = node_start_call.kwargs["attributes"]
assert attrs[SpanAttributeKey.CHAT_USAGE] == {
TokenUsageKey.INPUT_TOKENS: 5,
TokenUsageKey.OUTPUT_TOKENS: 10,
TokenUsageKey.TOTAL_TOKENS: 15,
}
assert attrs["usage"] == usage
node_span.end.assert_called_once()
workflow_span.end.assert_called_once()
@ -631,6 +641,27 @@ class TestMessageTrace:
assert "http://files.test/path/to/file.png" in attrs["file_list"]
assert "existing_file.txt" in attrs["file_list"]
def test_message_trace_preserves_structured_span_attributes(self, trace_instance, mock_tracing, mock_db):
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
trace_info = _make_message_trace_info(
metadata={
"conversation_id": "c1",
"from_account_id": "a1",
"routing": {"node": "answer", "score": 0.7},
},
file_list=["existing_file.txt"],
)
trace_instance.message_trace(trace_info)
attrs = mock_tracing["start"].call_args.kwargs["attributes"]
assert attrs["message_id"] == "msg-1"
assert attrs["total_price"] == 0.01
assert attrs["routing"] == {"node": "answer", "score": 0.7}
assert attrs["file_list"] == ["existing_file.txt"]
def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db):
span = MagicMock()
mock_tracing["start"].return_value = span

View File

@ -1,9 +1,3 @@
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/llm_generator/llm_generator.py
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/prompt/utils/prompt_message_util.py
core/rag/retrieval/dataset_retrieval.py
extensions/ext_celery.py
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py
providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py
@ -59,11 +53,6 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py
providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py
providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/workflow_as_tool/provider.py

View File

@ -4,9 +4,10 @@ import io
import os
import tempfile
from collections import UserDict
from collections.abc import Generator
from pathlib import Path
from types import SimpleNamespace
from typing import override
from typing import Protocol, cast, override
from unittest.mock import MagicMock
import pytest
@ -18,6 +19,14 @@ import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
class _TextOxmlElement(Protocol):
text: str | None
def _set_oxml_text(element: object, text: str) -> None:
cast(_TextOxmlElement, element).text = text
def _generate_table_with_merged_cells():
doc = Document()
@ -190,8 +199,8 @@ def test_extract_images_from_docx_uses_internal_files_url():
from configs import dify_config
# Mock the configuration values
original_files_url = getattr(dify_config, "FILES_URL", None)
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
original_files_url = dify_config.FILES_URL
original_internal_files_url = dify_config.INTERNAL_FILES_URL
try:
# Set both URLs - INTERNAL should take precedence
@ -233,7 +242,7 @@ def test_extract_hyperlinks(monkeypatch: pytest.MonkeyPatch):
new_run = OxmlElement("w:r")
t = OxmlElement("w:t")
t.text = "Dify"
_set_oxml_text(t, "Dify")
new_run.append(t)
hyperlink.append(new_run)
p._p.append(hyperlink)
@ -286,7 +295,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch):
run2 = OxmlElement("w:r")
instrText = OxmlElement("w:instrText")
instrText.text = ' HYPERLINK "http://example.com" '
_set_oxml_text(instrText, ' HYPERLINK "http://example.com" ')
run2.append(instrText)
p._p.append(run2)
@ -298,7 +307,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch):
run4 = OxmlElement("w:r")
t4 = OxmlElement("w:t")
t4.text = "Example"
_set_oxml_text(t4, "Example")
run4.append(t4)
p._p.append(run4)
@ -380,20 +389,27 @@ def test_close_is_idempotent():
extractor.temp_file.close.assert_called_once()
async def _async_close() -> None:
return None
def test_close_closes_awaitable_close_result():
class FakeAwaitable:
closed: bool = False
def __await__(self) -> Generator[None, None, None]:
if False:
yield None
return None
def close(self) -> None:
self.closed = True
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
close_result = _async_close()
close_result = FakeAwaitable()
extractor.temp_file.close = MagicMock(return_value=close_result)
extractor.close()
assert close_result.cr_frame is None
assert close_result.closed is True
extractor.temp_file.close.assert_called_once()
@ -506,6 +522,32 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch: pytest.MonkeyPatch):
assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain"
def test_parse_docx_reads_real_paragraph_table_order(monkeypatch: pytest.MonkeyPatch):
doc = Document()
doc.add_paragraph("Before table")
table = doc.add_table(rows=2, cols=2)
table.cell(0, 0).text = "Header A"
table.cell(0, 1).text = "Header B"
table.cell(1, 0).text = "Cell A"
table.cell(1, 1).text = "Cell B"
doc.add_paragraph("After table")
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
doc.save(tmp.name)
tmp_path = tmp.name
extractor = object.__new__(WordExtractor)
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: {})
try:
assert extractor.parse_docx(tmp_path) == (
"Before table\n| Header A | Header B |\n| --- | --- |\n| Cell A | Cell B |\nAfter table"
)
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch):
extractor = object.__new__(WordExtractor)
@ -620,8 +662,15 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
self.element = element
self.text = getattr(element, "text", "")
paragraph_main = SimpleNamespace(
_element=[
class FakeParagraph:
def __init__(self, children):
self._element = children
class FakeTable:
rows: list[object] = []
paragraph_main = FakeParagraph(
[
FakeChild(
qn("w:r"),
text="run-text",
@ -646,17 +695,16 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
),
]
)
paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")])
paragraph_empty = FakeParagraph([FakeChild(qn("w:r"), text=" ")])
table = FakeTable()
fake_doc = SimpleNamespace(
part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}),
paragraphs=[paragraph_main, paragraph_empty],
tables=[SimpleNamespace(rows=[])],
element=SimpleNamespace(
body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")]
),
iter_inner_content=lambda: iter([paragraph_main, paragraph_empty, table]),
)
monkeypatch.setattr(we, "Paragraph", FakeParagraph)
monkeypatch.setattr(we, "Table", FakeTable)
monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc)
monkeypatch.setattr(we, "Run", FakeRun)
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map)
@ -688,7 +736,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_http():
run_elem = OxmlElement("w:r")
t = OxmlElement("w:t")
t.text = "Dify"
_set_oxml_text(t, "Dify")
run_elem.append(t)
hyperlink.append(run_elem)
p._p.append(hyperlink)
@ -728,7 +776,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto():
run_elem = OxmlElement("w:r")
t = OxmlElement("w:t")
t.text = "john@test.com"
_set_oxml_text(t, "john@test.com")
run_elem.append(t)
hyperlink.append(run_elem)
p._p.append(hyperlink)

View File

@ -234,20 +234,6 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"])
def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9)
rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1)
with patch(
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [accepted, rejected]
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].metadata["score"] == 0.9
def test_index_list_chunks_high_quality(
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.entities.knowledge_entities import PreviewDetail
from core.rag.entities import ParentMode
from core.rag.entities import ParentMode, Rule, Segmentation
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
@ -293,29 +293,14 @@ class TestParentChildIndexProcessor:
mock_summary.assert_called_once_with(dataset, None)
def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8)
low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2)
with patch(
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [ok_result, low_result]
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "keep"
assert docs[0].metadata["score"] == 0.8
def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None:
rules = SimpleNamespace(subchunk_segmentation=None)
rules = Rule(subchunk_segmentation=None)
with pytest.raises(ValueError, match="No subchunk segmentation found"):
processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None)
def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None:
rules = SimpleNamespace(subchunk_segmentation=self._segmentation())
rules = Rule(subchunk_segmentation=Segmentation(max_tokens=200, chunk_overlap=10, separator="\n"))
splitter = Mock()
splitter.split_documents.return_value = [
Document(page_content=".child-1", metadata={}),

View File

@ -258,19 +258,6 @@ class TestQAIndexProcessor:
mock_summary.assert_called_once_with(dataset, None)
vector.delete.assert_called_once()
def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None:
result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9)
result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1)
with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
mock_retrieve.return_value = [result_ok, result_low]
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "accepted"
assert docs[0].metadata["score"] == 0.9
def test_index_adds_documents_and_vectors_for_high_quality(
self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
) -> None:
@ -331,7 +318,7 @@ class TestQAIndexProcessor:
def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None:
preview_items = [PreviewDetail(content="Q1")]
assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items
assert processor.generate_summary_preview("tenant-1", preview_items, {"enable": False}) is preview_items
def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None:
all_qa_documents: list[Document] = []

View File

@ -9,7 +9,6 @@ from core.entities.knowledge_entities import PreviewDetail
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
@ -52,17 +51,6 @@ class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
def format_preview(self, chunks):
return super().format_preview(chunks)
@override
def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model):
return super().retrieve(
retrieval_method=retrieval_method,
query=query,
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
class TestBaseIndexProcessor:
@pytest.fixture
@ -75,7 +63,7 @@ class TestBaseIndexProcessor:
with pytest.raises(NotImplementedError):
processor.transform([])
with pytest.raises(NotImplementedError):
processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {})
processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {"enable": False})
with pytest.raises(NotImplementedError):
processor.load(Mock(), [])
with pytest.raises(NotImplementedError):
@ -84,8 +72,6 @@ class TestBaseIndexProcessor:
processor.index(Mock(), Mock(), {})
with pytest.raises(NotImplementedError):
processor.format_preview([])
with pytest.raises(NotImplementedError):
processor.retrieve(RetrievalMethod.SEMANTIC_SEARCH, "q", Mock(), 3, 0.5, {})
def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None:
with patch(

View File

@ -1,7 +1,8 @@
import threading
from collections.abc import Generator
from contextlib import contextmanager, nullcontext
from types import SimpleNamespace
from typing import Any
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4
@ -86,6 +87,14 @@ def create_mock_document(
)
def _dataset(**values: object) -> Dataset:
return cast(Dataset, SimpleNamespace(**values))
def _metadata_condition() -> AppMetadataFilteringCondition:
return AppMetadataFilteringCondition(logical_operator="and", conditions=[])
def create_side_effect_for_search(documents: list[Document]):
"""
Create a side effect function for mocking search methods.
@ -2101,6 +2110,7 @@ class TestDocumentModel:
doc = Document(page_content="Test content", vector=vector)
assert doc.vector == vector
assert doc.vector is not None
assert len(doc.vector) == 5
def test_document_with_external_provider(self):
@ -2914,14 +2924,14 @@ class TestProcessMetadataFilterFunc:
return mock_string_access
elif name in ["year", "price", "rating"]:
return mock_float_access
elif name == "description":
return mock_null_access
else:
return mock_string_access
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
mock_metadata_field.as_string.return_value = mock_string_access
mock_metadata_field.as_float.return_value = mock_float_access
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
return mock_metadata_field
@ -3933,11 +3943,19 @@ class TestDatasetRetrievalAdditionalHelpers:
usage=None,
),
)
text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2]))
def _chunks() -> Generator[Any]:
yield chunk_1
yield chunk_2
text, returned_usage = retrieval._handle_invoke_result(_chunks())
assert text == "hello world"
assert returned_usage == usage
text_empty, usage_empty = retrieval._handle_invoke_result(iter([]))
def _empty_chunks() -> Generator[Any]:
yield from ()
text_empty, usage_empty = retrieval._handle_invoke_result(_empty_chunks())
assert text_empty == ""
assert usage_empty == LLMUsage.empty_usage()
@ -4176,7 +4194,9 @@ class TestDatasetRetrievalAdditionalHelpers:
)
assert mapping == {"d1": ["doc-1"]}
assert condition is not None
assert condition.conditions[0].value == "Alice"
assert condition.conditions
first_condition = condition.conditions[0]
assert first_condition.value == "Alice"
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
@ -4666,7 +4686,7 @@ class TestSingleAndMultipleRetrieveCoverage:
return DatasetRetrieval()
def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None:
dataset = SimpleNamespace(
dataset = _dataset(
id="ds-1",
name="External DS",
description=None,
@ -4711,7 +4731,7 @@ class TestSingleAndMultipleRetrieveCoverage:
assert retrieval.llm_usage.total_tokens == 2
def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None:
dataset = SimpleNamespace(
dataset = _dataset(
id="ds-1",
name="Internal DS",
description="dataset desc",
@ -4755,7 +4775,7 @@ class TestSingleAndMultipleRetrieveCoverage:
model_config=Mock(),
planning_strategy=PlanningStrategy.ROUTER,
metadata_filter_document_ids={"ds-1": ["doc-1"]},
metadata_condition=SimpleNamespace(),
metadata_condition=_metadata_condition(),
)
assert results == [result_doc]
@ -4772,7 +4792,7 @@ class TestSingleAndMultipleRetrieveCoverage:
user_from="workflow",
query="python",
available_datasets=[
SimpleNamespace(id="ds-1", name="DS", description=None),
_dataset(id="ds-1", name="DS", description=None),
],
model_instance=Mock(),
model_config=Mock(),
@ -4781,7 +4801,7 @@ class TestSingleAndMultipleRetrieveCoverage:
assert results == []
def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None:
dataset = SimpleNamespace(
dataset = _dataset(
id="ds-1",
name="Internal DS",
description="desc",
@ -4806,7 +4826,7 @@ class TestSingleAndMultipleRetrieveCoverage:
model_config=Mock(),
planning_strategy=PlanningStrategy.REACT_ROUTER,
metadata_filter_document_ids=None,
metadata_condition=SimpleNamespace(),
metadata_condition=_metadata_condition(),
)
missing_doc_ids = retrieval.single_retrieve(
app_id="app-1",
@ -4841,8 +4861,8 @@ class TestSingleAndMultipleRetrieveCoverage:
)
mixed = [
SimpleNamespace(id="d1", indexing_technique="high_quality"),
SimpleNamespace(id="d2", indexing_technique="economy"),
_dataset(id="d1", indexing_technique="high_quality"),
_dataset(id="d2", indexing_technique="economy"),
]
with pytest.raises(ValueError, match="different indexing technique"):
retrieval.multiple_retrieve(
@ -4859,13 +4879,13 @@ class TestSingleAndMultipleRetrieveCoverage:
)
high_quality_mismatch = [
SimpleNamespace(
_dataset(
id="d1",
indexing_technique="high_quality",
embedding_model="model-a",
embedding_model_provider="provider-a",
),
SimpleNamespace(
_dataset(
id="d2",
indexing_technique="high_quality",
embedding_model="model-b",
@ -4888,13 +4908,13 @@ class TestSingleAndMultipleRetrieveCoverage:
def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None:
datasets = [
SimpleNamespace(
_dataset(
id="d1",
indexing_technique="high_quality",
embedding_model="model-a",
embedding_model_provider="provider-a",
),
SimpleNamespace(
_dataset(
id="d2",
indexing_technique="high_quality",
embedding_model="model-a",
@ -4956,7 +4976,7 @@ class TestSingleAndMultipleRetrieveCoverage:
def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None:
datasets = [
SimpleNamespace(
_dataset(
id="d1",
indexing_technique="high_quality",
embedding_model="model-a",