mirror of
https://github.com/langgenius/dify.git
synced 2026-06-12 19:53:38 +08:00
chore(api): Fix several typing errors (#37237)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
99351d2f98
commit
b61d39ae2b
@ -24,14 +24,19 @@ from models.model import Message
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasedGenerateTaskPipeline:
|
||||
class BasedGenerateTaskPipeline[AppGenerateEntityT: AppGenerateEntity]:
|
||||
"""
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
|
||||
The type parameter preserves the concrete application generate entity for
|
||||
subclasses after the shared initializer stores it on ``_application_generate_entity``.
|
||||
"""
|
||||
|
||||
_application_generate_entity: AppGenerateEntityT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
application_generate_entity: AppGenerateEntityT,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
):
|
||||
|
||||
@ -65,19 +65,20 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought, Me
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
type EasyUIAppGenerateEntity = ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline[EasyUIAppGenerateEntity]):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
|
||||
_precomputed_event_type: StreamEvent | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
|
||||
application_generate_entity: EasyUIAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
@ -310,12 +311,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
yield response
|
||||
case QueueLLMChunkEvent() | QueueAgentMessageEvent():
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
delta_content = chunk.delta.message.content
|
||||
if delta_content is None:
|
||||
continue
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
if isinstance(delta_content, list):
|
||||
# EasyUI streams text only; structured multimodal chunks contribute their text parts.
|
||||
delta_text = ""
|
||||
for content in chunk.delta.message.content:
|
||||
for content in delta_content:
|
||||
logger.debug(
|
||||
"The content type %s in LLM chunk delta message content.: %r", type(content), content
|
||||
)
|
||||
@ -331,17 +333,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
content,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
delta_text = delta_content
|
||||
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
current_content = cast(str, self._task_state.llm_result.message.content)
|
||||
current_content += cast(str, delta_text)
|
||||
current_content += delta_text
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
match event:
|
||||
@ -352,13 +356,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message_id=self._message_id
|
||||
)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
answer=delta_text,
|
||||
message_id=self._message_id,
|
||||
event_type=self._precomputed_event_type,
|
||||
)
|
||||
case _:
|
||||
yield self._agent_message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
answer=delta_text,
|
||||
message_id=self._message_id,
|
||||
)
|
||||
case QueueMessageReplaceEvent():
|
||||
@ -389,9 +393,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation {self._conversation_id} not found")
|
||||
|
||||
message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
saved_prompt = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
object.__setattr__(message, "message", saved_prompt)
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
|
||||
@ -107,7 +107,7 @@ class LLMGenerator:
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
prompts = [UserPromptMessage(content=prompt)]
|
||||
prompts: list[PromptMessage] = [UserPromptMessage(content=prompt)]
|
||||
|
||||
with measure_time() as timer:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
@ -201,11 +201,13 @@ class LLMGenerator:
|
||||
except InvokeAuthorizationError:
|
||||
return []
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)]
|
||||
|
||||
questions: Sequence[str] = []
|
||||
|
||||
try:
|
||||
model_parameters: dict[str, object]
|
||||
stop: list[str]
|
||||
configured_completion_params = configured_model.get("completion_params")
|
||||
if use_configured_model and isinstance(configured_completion_params, dict):
|
||||
model_parameters, stop = _normalize_completion_params(configured_completion_params)
|
||||
@ -253,7 +255,7 @@ class LLMGenerator:
|
||||
remove_template_variables=False,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt_generate)]
|
||||
no_variable_prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate)]
|
||||
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
|
||||
@ -266,7 +268,7 @@ class LLMGenerator:
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(no_variable_prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
rule_config["prompt"] = response.message.get_text_content()
|
||||
@ -299,7 +301,7 @@ class LLMGenerator:
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
|
||||
prompt_generate_messages: list[PromptMessage] = [UserPromptMessage(content=prompt_generate_prompt)]
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
@ -314,7 +316,7 @@ class LLMGenerator:
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_generate_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
@ -331,7 +333,7 @@ class LLMGenerator:
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
parameter_messages: list[PromptMessage] = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
@ -341,7 +343,7 @@ class LLMGenerator:
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
statement_messages: list[PromptMessage] = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
@ -397,7 +399,7 @@ class LLMGenerator:
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
prompt_messages: list[PromptMessage] = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
@ -455,7 +457,7 @@ class LLMGenerator:
|
||||
model=args.model_config_data.name,
|
||||
)
|
||||
|
||||
prompt_messages = [
|
||||
prompt_messages: list[PromptMessage] = [
|
||||
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
|
||||
UserPromptMessage(content=args.instruction),
|
||||
]
|
||||
@ -634,7 +636,7 @@ class LLMGenerator:
|
||||
system_prompt = LLM_MODIFY_CODE_SYSTEM
|
||||
case _:
|
||||
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
|
||||
prompt_messages = [
|
||||
prompt_messages: list[PromptMessage] = [
|
||||
SystemPromptMessage(content=system_prompt),
|
||||
UserPromptMessage(
|
||||
content=json.dumps(
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from graphon.model_runtime.entities import (
|
||||
@ -13,19 +13,46 @@ from graphon.model_runtime.entities import (
|
||||
)
|
||||
|
||||
|
||||
class SavedPromptFile(TypedDict):
    """Serialized description of one file attached to a saved prompt message.

    Entries of this shape are collected into the ``files`` list of a ``SavedPrompt``.
    """

    # Kind of attachment; the completion-mode branch of the saver stores "image".
    type: str
    # Attachment payload; the completion-mode branch stores image data truncated
    # to its first and last 10 characters around a "...[TRUNCATED]..." marker.
    data: str
    # Optional key. NOTE(review): presumably the image detail level - confirm against the chat-mode branch.
    detail: NotRequired[str]
    # Optional key. NOTE(review): presumably the media format of the attachment - confirm.
    format: NotRequired[str]
|
||||
|
||||
|
||||
class SavedPromptToolCallFunction(TypedDict):
    """Function part of a saved tool call; used as ``SavedPromptToolCall["function"]``."""

    # Name of the function referenced by the tool call.
    name: str
    # Arguments kept as a string. NOTE(review): presumably JSON-encoded - confirm.
    arguments: str
|
||||
|
||||
|
||||
class SavedPromptToolCall(TypedDict):
    """Serialized tool call stored under the optional ``tool_calls`` key of a ``SavedPrompt``."""

    # Identifier of the tool call.
    id: str
    # Tool call type. NOTE(review): allowed values are not visible here - confirm against the producer.
    type: str
    # The function name and arguments of the call.
    function: SavedPromptToolCallFunction
|
||||
|
||||
|
||||
class SavedPrompt(TypedDict):
    """One saved prompt entry as returned by ``PromptMessageUtil.prompt_messages_to_prompt_for_saving``.

    ``role`` and ``text`` are always present; ``files`` and ``tool_calls`` are optional keys.
    """

    # Role label of the message; the saver writes values such as "user".
    role: str
    # Text content of the message.
    text: str
    # Attached files; the completion-mode branch only sets this key when files exist.
    files: NotRequired[list[SavedPromptFile]]
    # Tool calls; only set by the saver when the collected list is non-empty.
    tool_calls: NotRequired[list[SavedPromptToolCall]]
|
||||
|
||||
|
||||
class PromptMessageUtil:
|
||||
@staticmethod
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]):
|
||||
def prompt_messages_to_prompt_for_saving(
|
||||
model_mode: str, prompt_messages: Sequence[PromptMessage]
|
||||
) -> list[SavedPrompt]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param model_mode: model mode
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
prompts = []
|
||||
prompts: list[SavedPrompt] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
tool_calls = []
|
||||
for prompt_message in prompt_messages:
|
||||
tool_calls: list[SavedPromptToolCall] = []
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = "user"
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
@ -50,7 +77,7 @@ class PromptMessageUtil:
|
||||
continue
|
||||
|
||||
text = ""
|
||||
files = []
|
||||
files: list[SavedPromptFile] = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
match content:
|
||||
@ -77,7 +104,7 @@ class PromptMessageUtil:
|
||||
else:
|
||||
text = cast(str, prompt_message.content)
|
||||
|
||||
prompt = {"role": role, "text": text, "files": files}
|
||||
prompt: SavedPrompt = {"role": role, "text": text, "files": files}
|
||||
|
||||
if tool_calls:
|
||||
prompt["tool_calls"] = tool_calls
|
||||
@ -86,14 +113,14 @@ class PromptMessageUtil:
|
||||
else:
|
||||
prompt_message = prompt_messages[0]
|
||||
text = ""
|
||||
files = []
|
||||
prompt_files: list[SavedPromptFile] = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append(
|
||||
prompt_files.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
@ -103,13 +130,13 @@ class PromptMessageUtil:
|
||||
else:
|
||||
text = cast(str, prompt_message.content)
|
||||
|
||||
params: dict[str, Any] = {
|
||||
params: SavedPrompt = {
|
||||
"role": "user",
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if files:
|
||||
params["files"] = files
|
||||
if prompt_files:
|
||||
params["files"] = prompt_files
|
||||
|
||||
prompts.append(params)
|
||||
|
||||
|
||||
@ -105,6 +105,8 @@ class WaterCrawlProvider:
|
||||
|
||||
def scrape_url(self, url: str) -> WatercrawlDocumentData:
    """
    Scrape a single URL and return it as structured document data.

    The client is called with ``sync=True`` and ``prefetched=True``.

    :param url: address of the page to scrape
    :return: the client response passed through ``_structure_data``
    :raises ValueError: if the client returns anything other than a dictionary
    """
    payload = self.client.scrape_url(url=url, sync=True, prefetched=True)
    if isinstance(payload, dict):
        return self._structure_data(payload)
    # Only dictionary responses are handed to _structure_data.
    raise ValueError("Invalid scrape response. Expected a JSON dictionary.")
|
||||
|
||||
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
|
||||
|
||||
@ -15,6 +15,8 @@ from urllib.parse import urlparse
|
||||
|
||||
from docx import Document as DocxDocument
|
||||
from docx.oxml.ns import qn
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
from docx.text.run import Run
|
||||
|
||||
from configs import dify_config
|
||||
@ -286,10 +288,10 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
return "".join(paragraph_content).strip()
|
||||
|
||||
def parse_docx(self, docx_path):
|
||||
def parse_docx(self, docx_path: str) -> str:
|
||||
doc = DocxDocument(docx_path)
|
||||
|
||||
content = []
|
||||
content: list[str] = []
|
||||
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
@ -445,18 +447,11 @@ class WordExtractor(BaseExtractor):
|
||||
process_hyperlink(child, paragraph_content)
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
tables = doc.tables.copy()
|
||||
for element in doc.element.body:
|
||||
if hasattr(element, "tag"):
|
||||
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
|
||||
para = paragraphs.pop(0)
|
||||
parsed_paragraph = parse_paragraph(para)
|
||||
if parsed_paragraph.strip():
|
||||
content.append(parsed_paragraph)
|
||||
else:
|
||||
content.append("\n")
|
||||
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
|
||||
table = tables.pop(0)
|
||||
content.append(self._table_to_markdown(table, image_map))
|
||||
for block in doc.iter_inner_content():
|
||||
match block:
|
||||
case Paragraph():
|
||||
parsed_paragraph = parse_paragraph(block)
|
||||
content.append(parsed_paragraph if parsed_paragraph.strip() else "\n")
|
||||
case Table():
|
||||
content.append(self._table_to_markdown(block, image_map))
|
||||
return "\n".join(content)
|
||||
|
||||
@ -16,11 +16,9 @@ from sqlalchemy import select
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.file import remote_fetcher
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
@ -99,18 +97,6 @@ class BaseIndexProcessor(ABC):
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_splitter(
|
||||
self,
|
||||
processing_rule_mode: str,
|
||||
|
||||
@ -16,9 +16,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import Rule
|
||||
@ -28,7 +26,6 @@ from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_database import db
|
||||
@ -182,35 +179,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
documents: list[Any] = []
|
||||
|
||||
@ -12,8 +12,6 @@ from core.db.session_factory import session_factory
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import ParentMode, Rule
|
||||
@ -23,7 +21,6 @@ from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models import Account
|
||||
@ -223,35 +220,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def _split_child_nodes(
|
||||
self,
|
||||
document_node: Document,
|
||||
|
||||
@ -15,8 +15,6 @@ from core.db.session_factory import session_factory
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.entities import Rule
|
||||
@ -25,7 +23,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
@ -187,35 +184,6 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
@override
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: RerankingModelDict,
|
||||
):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata["score"] = result.score
|
||||
if result.score >= score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@override
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
|
||||
@ -609,7 +609,7 @@ class DatasetRetrieval:
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
tools = []
|
||||
tools: list[PromptMessageTool] = []
|
||||
for dataset in available_datasets:
|
||||
description = dataset.description
|
||||
if not description:
|
||||
@ -1162,7 +1162,7 @@ class DatasetRetrieval:
|
||||
:param invoke_from: invoke from
|
||||
:param hit_callback: hit callback
|
||||
"""
|
||||
tools = []
|
||||
tools: list[DatasetRetrieverBaseTool] = []
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
|
||||
@ -30,6 +30,11 @@ class CelerySSLOptionsDict(TypedDict):
|
||||
ssl_keyfile: str | None
|
||||
|
||||
|
||||
class CeleryBeatScheduleEntry(TypedDict):
    """Value type of the ``beat_schedule`` mapping built in ``init_app``."""

    # NOTE(review): presumably the dotted name of the task to run - confirm against the schedule entries.
    task: str
    # When the task runs: either a crontab expression or a fixed timedelta interval.
    schedule: crontab | timedelta
|
||||
|
||||
|
||||
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
"""Get SSL configuration for Celery broker/backend connections."""
|
||||
# Only apply SSL if we're using Redis as broker/backend
|
||||
@ -152,7 +157,7 @@ def init_app(app: DifyApp) -> Celery:
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
|
||||
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
|
||||
beat_schedule = {}
|
||||
beat_schedule: dict[str, CeleryBeatScheduleEntry] = {}
|
||||
if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
|
||||
imports.append("schedule.clean_embedding_cache_task")
|
||||
beat_schedule["clean_embedding_cache_task"] = {
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, cast, override
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
from mlflow.entities import Document, LiveSpan, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
|
||||
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
|
||||
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
|
||||
@ -31,6 +31,8 @@ from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
type SpanAttributes = dict[str, object]
|
||||
|
||||
|
||||
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
|
||||
"""Convert datetime to nanosecond timestamp for MLflow API"""
|
||||
@ -39,6 +41,32 @@ def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def _start_span_no_context(
    *,
    name: str,
    span_type: str,
    parent_span: LiveSpan | None = None,
    inputs: object | None = None,
    attributes: SpanAttributes | None = None,
    start_time_ns: int | None = None,
) -> LiveSpan:
    """Typed wrapper around MLflow's `start_span_no_context` that accepts structured attributes.

    MLflow 3.11 annotates `start_span_no_context(..., attributes=...)` as `dict[str, str]`,
    but the implementation immediately calls `LiveSpan.set_attributes(dict[str, Any])`.
    `LiveSpan` JSON-serializes arbitrary values before storing them in OpenTelemetry, and
    reserved attributes like `mlflow.chat.tokenUsage` are expected to round-trip as dicts.
    All arguments are keyword-only and are forwarded unchanged.
    """
    # The cast only satisfies MLflow's narrower annotation; the mapping itself is untouched.
    mlflow_attributes = cast(dict[str, str] | None, attributes)
    return start_span_no_context(
        name=name,
        span_type=span_type,
        parent_span=parent_span,
        inputs=inputs,
        attributes=mlflow_attributes,
        start_time_ns=start_time_ns,
    )
|
||||
|
||||
|
||||
class MLflowDataTrace(BaseTraceInstance):
|
||||
def __init__(self, config: MLflowConfig | DatabricksConfig):
|
||||
super().__init__(config)
|
||||
@ -119,7 +147,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
if trace_info.query:
|
||||
workflow_inputs["query"] = trace_info.query
|
||||
|
||||
workflow_span = start_span_no_context(
|
||||
workflow_span = _start_span_no_context(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=workflow_inputs,
|
||||
@ -139,7 +167,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
# Create child spans for workflow nodes
|
||||
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
inputs = None
|
||||
attributes = {
|
||||
attributes: SpanAttributes = {
|
||||
"node_id": node.id,
|
||||
"node_type": node.node_type,
|
||||
"status": node.status,
|
||||
@ -157,7 +185,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
if not inputs:
|
||||
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
|
||||
|
||||
node_span = start_span_no_context(
|
||||
node_span = _start_span_no_context(
|
||||
name=node.title,
|
||||
span_type=self._get_node_span_type(node.node_type),
|
||||
parent_span=workflow_span,
|
||||
@ -212,7 +240,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
|
||||
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[object, SpanAttributes]:
|
||||
"""Parse LLM inputs and attributes from LLM workflow node"""
|
||||
if node.process_data is None:
|
||||
return {}, {}
|
||||
@ -266,16 +294,16 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
file_list.append(f"{base_url}/{message_file_data.url}")
|
||||
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
span_type=SpanType.LLM,
|
||||
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
|
||||
inputs=self._parse_prompts(trace_info.inputs),
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"message_id": trace_info.message_id,
|
||||
"model_provider": trace_info.message_data.model_provider,
|
||||
"model_id": trace_info.message_data.model_id,
|
||||
"conversation_mode": trace_info.conversation_mode,
|
||||
"file_list": file_list, # type: ignore[dict-item]
|
||||
"file_list": file_list,
|
||||
"total_price": trace_info.message_data.total_price,
|
||||
**trace_info.metadata,
|
||||
},
|
||||
@ -330,15 +358,15 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
return metadata.get("from_account_id") # type: ignore[return-value]
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=trace_info.tool_name,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
|
||||
inputs=trace_info.tool_inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
|
||||
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
|
||||
"message_id": trace_info.message_id,
|
||||
"metadata": trace_info.metadata,
|
||||
"tool_config": trace_info.tool_config,
|
||||
"tool_parameters": trace_info.tool_parameters,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
@ -367,13 +395,13 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs or {},
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
"message_id": trace_info.message_id,
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
@ -391,13 +419,13 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
span_type=SpanType.RETRIEVER,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
"message_id": trace_info.message_id,
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
@ -410,15 +438,15 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
|
||||
"model_id": trace_info.model_id, # type: ignore[dict-item]
|
||||
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
|
||||
"message_id": trace_info.message_id,
|
||||
"model_provider": trace_info.model_provider,
|
||||
"model_id": trace_info.model_id,
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
@ -439,11 +467,11 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
span = start_span_no_context(
|
||||
span = _start_span_no_context(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
|
||||
attributes={"message_id": trace_info.message_id},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
|
||||
|
||||
@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
|
||||
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
|
||||
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
@ -361,6 +362,7 @@ class TestWorkflowTrace:
|
||||
assert inputs["query"] == "hello"
|
||||
|
||||
def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db):
|
||||
usage = {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||
llm_node = _make_node(
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
process_data=json.dumps(
|
||||
@ -369,7 +371,7 @@ class TestWorkflowTrace:
|
||||
"model_name": "gpt-4",
|
||||
"model_provider": "openai",
|
||||
"finish_reason": "stop",
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
|
||||
"usage": usage,
|
||||
}
|
||||
),
|
||||
outputs='{"text": "hello world"}',
|
||||
@ -383,6 +385,14 @@ class TestWorkflowTrace:
|
||||
|
||||
trace_instance.workflow_trace(_make_workflow_trace_info())
|
||||
assert mock_tracing["start"].call_count == 2
|
||||
node_start_call = mock_tracing["start"].call_args_list[1]
|
||||
attrs = node_start_call.kwargs["attributes"]
|
||||
assert attrs[SpanAttributeKey.CHAT_USAGE] == {
|
||||
TokenUsageKey.INPUT_TOKENS: 5,
|
||||
TokenUsageKey.OUTPUT_TOKENS: 10,
|
||||
TokenUsageKey.TOTAL_TOKENS: 15,
|
||||
}
|
||||
assert attrs["usage"] == usage
|
||||
node_span.end.assert_called_once()
|
||||
workflow_span.end.assert_called_once()
|
||||
|
||||
@ -631,6 +641,27 @@ class TestMessageTrace:
|
||||
assert "http://files.test/path/to/file.png" in attrs["file_list"]
|
||||
assert "existing_file.txt" in attrs["file_list"]
|
||||
|
||||
def test_message_trace_preserves_structured_span_attributes(self, trace_instance, mock_tracing, mock_db):
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={
|
||||
"conversation_id": "c1",
|
||||
"from_account_id": "a1",
|
||||
"routing": {"node": "answer", "score": 0.7},
|
||||
},
|
||||
file_list=["existing_file.txt"],
|
||||
)
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
attrs = mock_tracing["start"].call_args.kwargs["attributes"]
|
||||
assert attrs["message_id"] == "msg-1"
|
||||
assert attrs["total_price"] == 0.01
|
||||
assert attrs["routing"] == {"node": "answer", "score": 0.7}
|
||||
assert attrs["file_list"] == ["existing_file.txt"]
|
||||
|
||||
def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db):
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
|
||||
@ -1,9 +1,3 @@
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/llm_generator/llm_generator.py
|
||||
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
|
||||
core/prompt/utils/prompt_message_util.py
|
||||
core/rag/retrieval/dataset_retrieval.py
|
||||
extensions/ext_celery.py
|
||||
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py
|
||||
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py
|
||||
providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py
|
||||
@ -59,11 +53,6 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py
|
||||
providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py
|
||||
providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
|
||||
providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/tools/mcp_tool/provider.py
|
||||
core/tools/plugin_tool/provider.py
|
||||
core/tools/workflow_as_tool/provider.py
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -4,9 +4,10 @@ import io
|
||||
import os
|
||||
import tempfile
|
||||
from collections import UserDict
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import override
|
||||
from typing import Protocol, cast, override
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -18,6 +19,14 @@ import core.rag.extractor.word_extractor as we
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
|
||||
|
||||
class _TextOxmlElement(Protocol):
|
||||
text: str | None
|
||||
|
||||
|
||||
def _set_oxml_text(element: object, text: str) -> None:
|
||||
cast(_TextOxmlElement, element).text = text
|
||||
|
||||
|
||||
def _generate_table_with_merged_cells():
|
||||
doc = Document()
|
||||
|
||||
@ -190,8 +199,8 @@ def test_extract_images_from_docx_uses_internal_files_url():
|
||||
from configs import dify_config
|
||||
|
||||
# Mock the configuration values
|
||||
original_files_url = getattr(dify_config, "FILES_URL", None)
|
||||
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
|
||||
original_files_url = dify_config.FILES_URL
|
||||
original_internal_files_url = dify_config.INTERNAL_FILES_URL
|
||||
|
||||
try:
|
||||
# Set both URLs - INTERNAL should take precedence
|
||||
@ -233,7 +242,7 @@ def test_extract_hyperlinks(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
new_run = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "Dify"
|
||||
_set_oxml_text(t, "Dify")
|
||||
new_run.append(t)
|
||||
hyperlink.append(new_run)
|
||||
p._p.append(hyperlink)
|
||||
@ -286,7 +295,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
run2 = OxmlElement("w:r")
|
||||
instrText = OxmlElement("w:instrText")
|
||||
instrText.text = ' HYPERLINK "http://example.com" '
|
||||
_set_oxml_text(instrText, ' HYPERLINK "http://example.com" ')
|
||||
run2.append(instrText)
|
||||
p._p.append(run2)
|
||||
|
||||
@ -298,7 +307,7 @@ def test_extract_legacy_hyperlinks(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
run4 = OxmlElement("w:r")
|
||||
t4 = OxmlElement("w:t")
|
||||
t4.text = "Example"
|
||||
_set_oxml_text(t4, "Example")
|
||||
run4.append(t4)
|
||||
p._p.append(run4)
|
||||
|
||||
@ -380,20 +389,27 @@ def test_close_is_idempotent():
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
async def _async_close() -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_close_closes_awaitable_close_result():
|
||||
class FakeAwaitable:
|
||||
closed: bool = False
|
||||
|
||||
def __await__(self) -> Generator[None, None, None]:
|
||||
if False:
|
||||
yield None
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
close_result = _async_close()
|
||||
close_result = FakeAwaitable()
|
||||
extractor.temp_file.close = MagicMock(return_value=close_result)
|
||||
|
||||
extractor.close()
|
||||
|
||||
assert close_result.cr_frame is None
|
||||
assert close_result.closed is True
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
@ -506,6 +522,32 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch: pytest.MonkeyPatch):
|
||||
assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain"
|
||||
|
||||
|
||||
def test_parse_docx_reads_real_paragraph_table_order(monkeypatch: pytest.MonkeyPatch):
|
||||
doc = Document()
|
||||
doc.add_paragraph("Before table")
|
||||
table = doc.add_table(rows=2, cols=2)
|
||||
table.cell(0, 0).text = "Header A"
|
||||
table.cell(0, 1).text = "Header B"
|
||||
table.cell(1, 0).text = "Cell A"
|
||||
table.cell(1, 1).text = "Cell B"
|
||||
doc.add_paragraph("After table")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
extractor = object.__new__(WordExtractor)
|
||||
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: {})
|
||||
|
||||
try:
|
||||
assert extractor.parse_docx(tmp_path) == (
|
||||
"Before table\n| Header A | Header B |\n| --- | --- |\n| Cell A | Cell B |\nAfter table"
|
||||
)
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch: pytest.MonkeyPatch):
|
||||
extractor = object.__new__(WordExtractor)
|
||||
|
||||
@ -620,8 +662,15 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
|
||||
self.element = element
|
||||
self.text = getattr(element, "text", "")
|
||||
|
||||
paragraph_main = SimpleNamespace(
|
||||
_element=[
|
||||
class FakeParagraph:
|
||||
def __init__(self, children):
|
||||
self._element = children
|
||||
|
||||
class FakeTable:
|
||||
rows: list[object] = []
|
||||
|
||||
paragraph_main = FakeParagraph(
|
||||
[
|
||||
FakeChild(
|
||||
qn("w:r"),
|
||||
text="run-text",
|
||||
@ -646,17 +695,16 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
|
||||
),
|
||||
]
|
||||
)
|
||||
paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")])
|
||||
paragraph_empty = FakeParagraph([FakeChild(qn("w:r"), text=" ")])
|
||||
table = FakeTable()
|
||||
|
||||
fake_doc = SimpleNamespace(
|
||||
part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}),
|
||||
paragraphs=[paragraph_main, paragraph_empty],
|
||||
tables=[SimpleNamespace(rows=[])],
|
||||
element=SimpleNamespace(
|
||||
body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")]
|
||||
),
|
||||
iter_inner_content=lambda: iter([paragraph_main, paragraph_empty, table]),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(we, "Paragraph", FakeParagraph)
|
||||
monkeypatch.setattr(we, "Table", FakeTable)
|
||||
monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc)
|
||||
monkeypatch.setattr(we, "Run", FakeRun)
|
||||
monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map)
|
||||
@ -688,7 +736,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_http():
|
||||
|
||||
run_elem = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "Dify"
|
||||
_set_oxml_text(t, "Dify")
|
||||
run_elem.append(t)
|
||||
hyperlink.append(run_elem)
|
||||
p._p.append(hyperlink)
|
||||
@ -728,7 +776,7 @@ def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto():
|
||||
|
||||
run_elem = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "john@test.com"
|
||||
_set_oxml_text(t, "john@test.com")
|
||||
run_elem.append(t)
|
||||
hyperlink.append(run_elem)
|
||||
p._p.append(hyperlink)
|
||||
|
||||
@ -234,20 +234,6 @@ class TestParagraphIndexProcessor:
|
||||
|
||||
mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"])
|
||||
|
||||
def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
|
||||
accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9)
|
||||
rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [accepted, rejected]
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["score"] == 0.9
|
||||
|
||||
def test_index_list_chunks_high_quality(
|
||||
self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock
|
||||
) -> None:
|
||||
|
||||
@ -4,7 +4,7 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.entities import ParentMode
|
||||
from core.rag.entities import ParentMode, Rule, Segmentation
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
@ -293,29 +293,14 @@ class TestParentChildIndexProcessor:
|
||||
|
||||
mock_summary.assert_called_once_with(dataset, None)
|
||||
|
||||
def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8)
|
||||
low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2)
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.return_value = [ok_result, low_result]
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "keep"
|
||||
assert docs[0].metadata["score"] == 0.8
|
||||
|
||||
def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None:
|
||||
rules = SimpleNamespace(subchunk_segmentation=None)
|
||||
rules = Rule(subchunk_segmentation=None)
|
||||
|
||||
with pytest.raises(ValueError, match="No subchunk segmentation found"):
|
||||
processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None)
|
||||
|
||||
def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None:
|
||||
rules = SimpleNamespace(subchunk_segmentation=self._segmentation())
|
||||
rules = Rule(subchunk_segmentation=Segmentation(max_tokens=200, chunk_overlap=10, separator="\n"))
|
||||
splitter = Mock()
|
||||
splitter.split_documents.return_value = [
|
||||
Document(page_content=".child-1", metadata={}),
|
||||
|
||||
@ -258,19 +258,6 @@ class TestQAIndexProcessor:
|
||||
mock_summary.assert_called_once_with(dataset, None)
|
||||
vector.delete.assert_called_once()
|
||||
|
||||
def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None:
|
||||
result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9)
|
||||
result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1)
|
||||
|
||||
with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
|
||||
mock_retrieve.return_value = [result_ok, result_low]
|
||||
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
|
||||
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "accepted"
|
||||
assert docs[0].metadata["score"] == 0.9
|
||||
|
||||
def test_index_adds_documents_and_vectors_for_high_quality(
|
||||
self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock
|
||||
) -> None:
|
||||
@ -331,7 +318,7 @@ class TestQAIndexProcessor:
|
||||
|
||||
def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None:
|
||||
preview_items = [PreviewDetail(content="Q1")]
|
||||
assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items
|
||||
assert processor.generate_summary_preview("tenant-1", preview_items, {"enable": False}) is preview_items
|
||||
|
||||
def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None:
|
||||
all_qa_documents: list[Document] = []
|
||||
|
||||
@ -9,7 +9,6 @@ from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
|
||||
@ -52,17 +51,6 @@ class _ForwardingBaseIndexProcessor(BaseIndexProcessor):
|
||||
def format_preview(self, chunks):
|
||||
return super().format_preview(chunks)
|
||||
|
||||
@override
|
||||
def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model):
|
||||
return super().retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
query=query,
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
|
||||
|
||||
class TestBaseIndexProcessor:
|
||||
@pytest.fixture
|
||||
@ -75,7 +63,7 @@ class TestBaseIndexProcessor:
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.transform([])
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {})
|
||||
processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {"enable": False})
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.load(Mock(), [])
|
||||
with pytest.raises(NotImplementedError):
|
||||
@ -84,8 +72,6 @@ class TestBaseIndexProcessor:
|
||||
processor.index(Mock(), Mock(), {})
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.format_preview([])
|
||||
with pytest.raises(NotImplementedError):
|
||||
processor.retrieve(RetrievalMethod.SEMANTIC_SEARCH, "q", Mock(), 3, 0.5, {})
|
||||
|
||||
def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
with patch(
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -86,6 +87,14 @@ def create_mock_document(
|
||||
)
|
||||
|
||||
|
||||
def _dataset(**values: object) -> Dataset:
|
||||
return cast(Dataset, SimpleNamespace(**values))
|
||||
|
||||
|
||||
def _metadata_condition() -> AppMetadataFilteringCondition:
|
||||
return AppMetadataFilteringCondition(logical_operator="and", conditions=[])
|
||||
|
||||
|
||||
def create_side_effect_for_search(documents: list[Document]):
|
||||
"""
|
||||
Create a side effect function for mocking search methods.
|
||||
@ -2101,6 +2110,7 @@ class TestDocumentModel:
|
||||
doc = Document(page_content="Test content", vector=vector)
|
||||
|
||||
assert doc.vector == vector
|
||||
assert doc.vector is not None
|
||||
assert len(doc.vector) == 5
|
||||
|
||||
def test_document_with_external_provider(self):
|
||||
@ -2914,14 +2924,14 @@ class TestProcessMetadataFilterFunc:
|
||||
return mock_string_access
|
||||
elif name in ["year", "price", "rating"]:
|
||||
return mock_float_access
|
||||
elif name == "description":
|
||||
return mock_null_access
|
||||
else:
|
||||
return mock_string_access
|
||||
|
||||
mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect)
|
||||
mock_metadata_field.as_string.return_value = mock_string_access
|
||||
mock_metadata_field.as_float.return_value = mock_float_access
|
||||
mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_
|
||||
mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot
|
||||
|
||||
return mock_metadata_field
|
||||
|
||||
@ -3933,11 +3943,19 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2]))
|
||||
|
||||
def _chunks() -> Generator[Any]:
|
||||
yield chunk_1
|
||||
yield chunk_2
|
||||
|
||||
text, returned_usage = retrieval._handle_invoke_result(_chunks())
|
||||
assert text == "hello world"
|
||||
assert returned_usage == usage
|
||||
|
||||
text_empty, usage_empty = retrieval._handle_invoke_result(iter([]))
|
||||
def _empty_chunks() -> Generator[Any]:
|
||||
yield from ()
|
||||
|
||||
text_empty, usage_empty = retrieval._handle_invoke_result(_empty_chunks())
|
||||
assert text_empty == ""
|
||||
assert usage_empty == LLMUsage.empty_usage()
|
||||
|
||||
@ -4176,7 +4194,9 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
)
|
||||
assert mapping == {"d1": ["doc-1"]}
|
||||
assert condition is not None
|
||||
assert condition.conditions[0].value == "Alice"
|
||||
assert condition.conditions
|
||||
first_condition = condition.conditions[0]
|
||||
assert first_condition.value == "Alice"
|
||||
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
|
||||
@ -4666,7 +4686,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
return DatasetRetrieval()
|
||||
|
||||
def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None:
|
||||
dataset = SimpleNamespace(
|
||||
dataset = _dataset(
|
||||
id="ds-1",
|
||||
name="External DS",
|
||||
description=None,
|
||||
@ -4711,7 +4731,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
assert retrieval.llm_usage.total_tokens == 2
|
||||
|
||||
def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None:
|
||||
dataset = SimpleNamespace(
|
||||
dataset = _dataset(
|
||||
id="ds-1",
|
||||
name="Internal DS",
|
||||
description="dataset desc",
|
||||
@ -4755,7 +4775,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
model_config=Mock(),
|
||||
planning_strategy=PlanningStrategy.ROUTER,
|
||||
metadata_filter_document_ids={"ds-1": ["doc-1"]},
|
||||
metadata_condition=SimpleNamespace(),
|
||||
metadata_condition=_metadata_condition(),
|
||||
)
|
||||
|
||||
assert results == [result_doc]
|
||||
@ -4772,7 +4792,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
user_from="workflow",
|
||||
query="python",
|
||||
available_datasets=[
|
||||
SimpleNamespace(id="ds-1", name="DS", description=None),
|
||||
_dataset(id="ds-1", name="DS", description=None),
|
||||
],
|
||||
model_instance=Mock(),
|
||||
model_config=Mock(),
|
||||
@ -4781,7 +4801,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
assert results == []
|
||||
|
||||
def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None:
|
||||
dataset = SimpleNamespace(
|
||||
dataset = _dataset(
|
||||
id="ds-1",
|
||||
name="Internal DS",
|
||||
description="desc",
|
||||
@ -4806,7 +4826,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
model_config=Mock(),
|
||||
planning_strategy=PlanningStrategy.REACT_ROUTER,
|
||||
metadata_filter_document_ids=None,
|
||||
metadata_condition=SimpleNamespace(),
|
||||
metadata_condition=_metadata_condition(),
|
||||
)
|
||||
missing_doc_ids = retrieval.single_retrieve(
|
||||
app_id="app-1",
|
||||
@ -4841,8 +4861,8 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
)
|
||||
|
||||
mixed = [
|
||||
SimpleNamespace(id="d1", indexing_technique="high_quality"),
|
||||
SimpleNamespace(id="d2", indexing_technique="economy"),
|
||||
_dataset(id="d1", indexing_technique="high_quality"),
|
||||
_dataset(id="d2", indexing_technique="economy"),
|
||||
]
|
||||
with pytest.raises(ValueError, match="different indexing technique"):
|
||||
retrieval.multiple_retrieve(
|
||||
@ -4859,13 +4879,13 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
)
|
||||
|
||||
high_quality_mismatch = [
|
||||
SimpleNamespace(
|
||||
_dataset(
|
||||
id="d1",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="model-a",
|
||||
embedding_model_provider="provider-a",
|
||||
),
|
||||
SimpleNamespace(
|
||||
_dataset(
|
||||
id="d2",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="model-b",
|
||||
@ -4888,13 +4908,13 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
|
||||
def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None:
|
||||
datasets = [
|
||||
SimpleNamespace(
|
||||
_dataset(
|
||||
id="d1",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="model-a",
|
||||
embedding_model_provider="provider-a",
|
||||
),
|
||||
SimpleNamespace(
|
||||
_dataset(
|
||||
id="d2",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="model-a",
|
||||
@ -4956,7 +4976,7 @@ class TestSingleAndMultipleRetrieveCoverage:
|
||||
|
||||
def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None:
|
||||
datasets = [
|
||||
SimpleNamespace(
|
||||
_dataset(
|
||||
id="d1",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="model-a",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user