diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index f5ba498c7d..dada6229db 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -26,6 +26,7 @@ jobs: - name: ast-grep run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - + - name: mdformat + run: | + uvx mdformat . - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 - diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index d7500c415c..30c890c301 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional import flask_restful from flask_login import current_user @@ -49,7 +49,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Any = None + resource_model: Optional[Any] = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -102,7 +102,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Any = None + resource_model: Optional[Any] = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3e1237615a..2e8e7ae4b2 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,3 +1,4 @@ +import contextlib import json import os import time @@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str): def cloud_utm_record(view): @wraps(view) def decorated(*args, **kwargs): - try: + with contextlib.suppress(Exception): features = 
FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: @@ -187,8 +188,7 @@ def cloud_utm_record(view): if utm_info: utm_info_dict: dict = json.loads(utm_info) OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) - except Exception as e: - pass + return view(*args, **kwargs) return decorated diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ad9b625350..f7c83f927f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -512,7 +512,6 @@ class BaseAgentRunner(AppRunner): if not file_objs: return UserPromptMessage(content=message.query) prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file in file_objs: prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -520,4 +519,6 @@ class BaseAgentRunner(AppRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 5ff89bdacb..4d1d94eadc 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -39,9 +39,6 @@ class CotChatAgentRunner(CotAgentRunner): Organize user query """ if self.files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -52,6 +49,8 @@ class CotChatAgentRunner(CotAgentRunner): else None ) image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in self.files: 
prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -59,6 +58,7 @@ class CotChatAgentRunner(CotAgentRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 4df71ce9de..4e6fe60e57 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -395,9 +395,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): Organize user query """ if self.files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -408,6 +405,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): else None ) image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in self.files: prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -415,6 +414,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6c2a342289..895ee8581e 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -178,7 +178,7 @@ class ModelConfig(BaseModel): provider: str name: str mode: LLMMode - completion_params: dict[str, Any] = {} + completion_params: dict[str, Any] = Field(default_factory=dict) class Condition(BaseModel): diff --git a/api/core/app/entities/queue_entities.py 
b/api/core/app/entities/queue_entities.py index 42e6a1519c..d663dbb175 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -610,7 +610,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Any = None + error: Optional[Any] = None class QueuePingEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 25c889e922..a1c0368354 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -142,7 +142,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) files: Optional[Sequence[Mapping[str, Any]]] = None @@ -261,7 +261,7 @@ class NodeStartStreamResponse(StreamResponse): predecessor_node_id: Optional[str] = None inputs: Optional[Mapping[str, Any]] = None created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parent_parallel_id: Optional[str] = None @@ -503,7 +503,7 @@ class IterationNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} parallel_id: Optional[str] = None @@ -531,7 +531,7 @@ class IterationNodeNextStreamResponse(StreamResponse): index: int created_at: int pre_iteration_output: Optional[Any] = None - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parallel_mode_run_id: Optional[str] = None @@ -590,7 +590,7 @@ class LoopNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} parallel_id: 
Optional[str] = None @@ -618,7 +618,7 @@ class LoopNodeNextStreamResponse(StreamResponse): index: int created_at: int pre_loop_output: Optional[Any] = None - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parallel_mode_run_id: Optional[str] = None @@ -764,7 +764,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): conversation_id: str message_id: str answer: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) created_at: int data: Data @@ -784,7 +784,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): mode: str message_id: str answer: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) created_at: int data: Data diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 014c7fd4f5..8c0a442158 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -52,7 +52,8 @@ class BasedGenerateTaskPipeline: elif isinstance(e, InvokeError | ValueError): err = e else: - err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) + description = getattr(e, "description", None) + err = Exception(description if description is not None else str(e)) if not message_id or not session: return err diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 557f7eb1ed..ae4671a381 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -17,7 +17,7 @@ class ExtensionModule(enum.Enum): class ModuleExtension(BaseModel): - extension_class: Any = None + extension_class: Optional[Any] = None name: str label: Optional[dict] = None form_schema: Optional[list] = None diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 9eb9e0306b..50c3f9b5f4 100644 --- 
a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -38,6 +38,7 @@ class Extension: def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) + assert module_extension.extension_class is not None t: type = module_extension.extension_class return t diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index df42837796..5cd0ea5c66 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -1,3 +1,4 @@ +import contextlib import re from collections.abc import Mapping from typing import Any, Optional @@ -97,10 +98,8 @@ def parse_traceparent_header(traceparent: str) -> Optional[str]: Reference: W3C Trace Context Specification: https://www.w3.org/TR/trace-context/ """ - try: + with contextlib.suppress(Exception): parts = traceparent.split("-") if len(parts) == 4 and len(parts[1]) == 32: return parts[1] - except Exception: - pass return None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index c2e4d72cce..afb77d248e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,7 +9,6 @@ import uuid from typing import Any, Optional, cast from flask import current_app -from flask_login import current_user from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -295,7 +294,7 @@ class IndexingRunner: text_docs, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, doc_language=doc_language, preview=True, ) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 3456770a2e..eb783297c3 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -5,7 +5,7 @@ import os import secrets import urllib.parse from typing import Optional -from urllib.parse import urljoin +from urllib.parse 
import urljoin, urlparse import httpx from pydantic import BaseModel, ValidationError @@ -99,9 +99,37 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta return full_state_data +def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: + """Check if the server supports OAuth 2.0 Resource Discovery.""" + b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + if b_query: + url_for_resource_discovery += f"?{b_query}" + if b_fragment: + url_for_resource_discovery += f"#{b_fragment}" + try: + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} + response = httpx.get(url_for_resource_discovery, headers=headers) + if 200 <= response.status_code < 300: + body = response.json() + if "authorization_server_url" in body: + return True, body["authorization_server_url"][0] + else: + return False, "" + return False, "" + except httpx.RequestError as e: + # Not support resource discovery, fall back to well-known OAuth metadata + return False, "" + + def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" - url = urljoin(server_url, "/.well-known/oauth-authorization-server") + # First check if the server supports OAuth 2.0 Resource Discovery + support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) + if support_resource_discovery: + url = oauth_discovery_url + else: + url = urljoin(server_url, "/.well-known/oauth-authorization-server") try: headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 3f98aa94ae..031f01f411 100644 --- a/api/core/mcp/session/base_session.py +++ 
b/api/core/mcp/session/base_session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Self, TypeVar +from typing import Any, Generic, Optional, Self, TypeVar from httpx import HTTPStatusError from pydantic import BaseModel @@ -209,7 +209,7 @@ class BaseSession( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, - metadata: MessageMetadata = None, + metadata: Optional[MessageMetadata] = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 99d985a781..49aa8e4498 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1173,7 +1173,7 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: MessageMetadata = None + metadata: Optional[MessageMetadata] = None class OAuthClientMetadata(BaseModel): diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index ace2c1f770..0e1277bc86 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum @@ -54,7 +56,7 @@ class LLMUsage(ModelUsage): ) @classmethod - def from_metadata(cls, metadata: dict) -> "LLMUsage": + def from_metadata(cls, metadata: dict) -> LLMUsage: """ Create LLMUsage instance from metadata dictionary with default values. 
@@ -84,7 +86,7 @@ class LLMUsage(ModelUsage): latency=metadata.get("latency", 0.0), ) - def plus(self, other: "LLMUsage") -> "LLMUsage": + def plus(self, other: LLMUsage) -> LLMUsage: """ Add two LLMUsage instances together. @@ -109,7 +111,7 @@ class LLMUsage(ModelUsage): latency=self.latency + other.latency, ) - def __add__(self, other: "LLMUsage") -> "LLMUsage": + def __add__(self, other: LLMUsage) -> LLMUsage: """ Overload the + operator to add two LLMUsage instances. diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index b7db0b78bc..68d30112d9 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -1,10 +1,10 @@ import logging from threading import Lock -from typing import Any +from typing import Any, Optional logger = logging.getLogger(__name__) -_tokenizer: Any = None +_tokenizer: Optional[Any] = None _lock = Lock() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 332381555b..af51b72cd5 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token @@ -11,7 +11,7 @@ from models.api_based_extension import APIBasedExtension class ModerationInputParams(BaseModel): app_id: str = "" - inputs: dict = {} + inputs: dict = Field(default_factory=dict) query: str = "" diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index d8c392d097..99bd0049c0 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing 
import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule @@ -16,7 +16,7 @@ class ModerationInputsResult(BaseModel): flagged: bool = False action: ModerationAction preset_response: str = "" - inputs: dict = {} + inputs: dict = Field(default_factory=dict) query: str = "" diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f0fe65f27..16c145f936 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -125,11 +125,11 @@ class AdvancedPromptTransform(PromptTransform): if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -196,16 +196,17 @@ class AdvancedPromptTransform(PromptTransform): query = parser.format(prompt_inputs) + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) if files and query is not None: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) @@ -215,27 +216,27 @@ class 
AdvancedPromptTransform(PromptTransform): last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content))) last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data="")) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e19c6419ca..13f4163d80 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -265,11 +265,11 @@ class SimplePromptTransform(PromptTransform): ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + 
prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 789a032654..7406919597 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,3 +1,4 @@ +import contextlib import json from collections import defaultdict from json import JSONDecodeError @@ -624,14 +625,12 @@ class ProviderManager: for variable in provider_credential_secret_variables: if variable in provider_credentials: - try: + with contextlib.suppress(ValueError): provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials.get(variable) or "", # type: ignore self.decoding_rsa_key, self.decoding_cipher_rsa, ) - except ValueError: - pass # cache provider credentials provider_credentials_cache.set(credentials=provider_credentials) @@ -672,14 +671,12 @@ class ProviderManager: for variable in model_credential_secret_variables: if variable in provider_model_credentials: - try: + with contextlib.suppress(ValueError): provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa, ) - except ValueError: - pass # cache provider model credentials provider_model_credentials_cache.set(credentials=provider_model_credentials) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 2df17181a4..bb61b71bb1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -105,9 +105,11 @@ class AnalyticdbVectorBySql: conn.close() self.pool = self._create_connection_pool() with self._get_cursor() as cur: + conn = cur.connection try: cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;") except Exception as e: + 
conn.rollback() raise RuntimeError( "Failed to create zhparser extension. Please ensure it is available in your AnalyticDB." ) from e @@ -115,6 +117,7 @@ class AnalyticdbVectorBySql: cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)") cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple") except Exception as e: + conn.rollback() if "already exists" not in str(e): raise e cur.execute( diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 1059b855a2..6e8077ffd9 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1,3 +1,4 @@ +import contextlib import json import logging import queue @@ -214,10 +215,8 @@ class ClickzettaConnectionPool: return connection else: # Connection expired or invalid, close it - try: + with contextlib.suppress(Exception): connection.close() - except Exception: - pass # No valid connection found, create new one return self._create_connection(config) @@ -228,10 +227,8 @@ class ClickzettaConnectionPool: if config_key not in self._pool_locks: # Pool was cleaned up, just close the connection - try: + with contextlib.suppress(Exception): connection.close() - except Exception: - pass return with self._pool_locks[config_key]: @@ -243,10 +240,8 @@ class ClickzettaConnectionPool: logger.debug("Returned ClickZetta connection to pool") else: # Pool full or connection invalid, close it - try: + with contextlib.suppress(Exception): connection.close() - except Exception: - pass def _cleanup_expired_connections(self) -> None: """Clean up expired connections from all pools.""" @@ -265,10 +260,8 @@ class ClickzettaConnectionPool: if current_time - last_used < self._connection_timeout: valid_connections.append((connection, last_used)) else: - try: + with contextlib.suppress(Exception): connection.close() - except Exception: - 
pass self._pools[config_key] = valid_connections @@ -299,10 +292,8 @@ class ClickzettaConnectionPool: with self._pool_locks[config_key]: pool = self._pools[config_key] for connection, _ in pool: - try: + with contextlib.suppress(Exception): connection.close() - except Exception: - pass pool.clear() diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 04033dec3f..7dfe2e357c 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,5 +1,6 @@ """Abstract interface for document loader implementations.""" +import contextlib from collections.abc import Iterator from typing import Optional, cast @@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor): def extract(self) -> list[Document]: plaintext_file_exists = False if self._file_cache_key: - try: + with contextlib.suppress(FileNotFoundError): text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] - except FileNotFoundError: - pass documents = list(self.load()) text_list = [] for document in documents: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index f1fa5dde5c..856a9bce18 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,4 +1,5 @@ import base64 +import contextlib import logging from typing import Optional @@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor): elements = partition_email(filename=self._file_path) # noinspection PyBroadException - try: + with contextlib.suppress(Exception): for element in elements: element_text = element.text.strip() @@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor): element_decode = base64.b64decode(element_text) soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") 
element.text = soup.get_text() - except Exception: - pass from unstructured.chunking.title import chunk_by_title diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index 21fbb2100f..da03fc67a6 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, Optional from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: dict | Any = None) -> dict: + def crawl_url(self, url, options: Optional[dict | Any] = None) -> dict: options = options or {} spider_options = { "max_depth": 1, diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index c37203f1c8..5ecd2f796b 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field class ChildDocument(BaseModel): @@ -15,7 +15,7 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = {} + metadata: dict = Field(default_factory=dict) class Document(BaseModel): @@ -28,7 +28,7 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). 
""" - metadata: dict = {} + metadata: dict = Field(default_factory=dict) provider: Optional[str] = "dify" diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index a25bc65646..cd4af72832 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1012,7 +1012,7 @@ class DatasetRetrieval: def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): - if value is None: + if value is None and condition not in ("empty", "not empty"): return key = f"{metadata_name}_{sequence}" diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b1a7eacf0e..0e64e45da5 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,4 +1,5 @@ import base64 +import contextlib import enum from collections.abc import Mapping from enum import Enum @@ -227,10 +228,8 @@ class ToolInvokeMessage(BaseModel): @classmethod def decode_blob_message(cls, v): if isinstance(v, dict) and "blob" in v: - try: + with contextlib.suppress(Exception): v["blob"] = base64.b64decode(v["blob"]) - except Exception: - pass return v @field_serializer("message") diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 83444c02d8..10db4d9503 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,3 +1,4 @@ +import contextlib import json from collections.abc import Generator, Iterable from copy import deepcopy @@ -69,10 +70,8 @@ class ToolEngine: if parameters and len(parameters) == 1: tool_parameters = {parameters[0].name: tool_parameters} else: - try: + with contextlib.suppress(Exception): tool_parameters = json.loads(tool_parameters) - except Exception: - pass if not isinstance(tool_parameters, dict): raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") @@ -270,14 
+269,12 @@ class ToolEngine: if response.meta.get("mime_type"): mimetype = response.meta.get("mime_type") else: - try: + with contextlib.suppress(Exception): url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result - except Exception: - pass if not mimetype: mimetype = "image/jpeg" diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index aceba6e69f..3a9391dbb1 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,3 +1,4 @@ +import contextlib from copy import deepcopy from typing import Any @@ -137,11 +138,9 @@ class ToolParameterConfigurationManager: and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT ): if parameter.name in parameters: - try: - has_secret_input = True + has_secret_input = True + with contextlib.suppress(Exception): parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) - except Exception: - pass if has_secret_input: cache.set(parameters) diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 7f4113bb77..f75b7947f1 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,3 +1,4 @@ +import contextlib from copy import deepcopy from typing import Any, Optional, Protocol @@ -111,14 +112,12 @@ class ProviderConfigEncrypter: for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: - try: + with contextlib.suppress(Exception): # if the value is None or empty string, skip decrypt if not data[field_name]: continue data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass self.provider_config_cache.set(data) return data diff --git a/api/core/tools/utils/web_reader_tool.py 
b/api/core/tools/utils/web_reader_tool.py index 770c0ef7bd..d8403c2e15 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -80,7 +80,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: else: content = response.text - article = extract_using_readability(content) + article = extract_using_readabilipy(content) if not article.text: return "" @@ -101,7 +101,7 @@ class Article: text: Sequence[dict] -def extract_using_readability(html: str): +def extract_using_readabilipy(html: str): json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) article = Article( title=json_article.get("title") or "", diff --git a/api/core/variables/types.py b/api/core/variables/types.py index d28fb11401..6629056042 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -126,7 +126,7 @@ class SegmentType(StrEnum): """ if self.is_array_type(): return self._validate_array(value, array_validation) - elif self == SegmentType.NUMBER: + elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: return isinstance(value, (int, float)) elif self == SegmentType.STRING: return isinstance(value, str) @@ -166,7 +166,6 @@ _ARRAY_TYPES = frozenset( ] ) - _NUMERICAL_TYPES = frozenset( [ SegmentType.NUMBER, diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index a62ffe46c9..e2ec7b17f0 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -22,7 +22,7 @@ class GraphRuntimeState(BaseModel): # # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent # after a serialization and deserialization round trip. 
- outputs: dict[str, Any] = {} + outputs: dict[str, Any] = Field(default_factory=dict) node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7303b68501..5e5c9f520e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -175,7 +175,7 @@ class KnowledgeRetrievalNode(BaseNode): redis_client.zremrangebyscore(key, 0, current_time - 60000) request_count = redis_client.zcard(key) if request_count > knowledge_rate_limit.limit: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # add ratelimit record rate_limit_log = RateLimitLog( tenant_id=self.tenant_id, @@ -183,7 +183,6 @@ class KnowledgeRetrievalNode(BaseNode): operation="knowledge", ) session.add(rate_limit_log) - session.commit() return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -389,6 +388,15 @@ class KnowledgeRetrievalNode(BaseNode): "segment_id": segment.id, "retriever_from": "workflow", "score": record.score or 0.0, + "child_chunks": [ + { + "id": str(getattr(chunk, "id", "")), + "content": str(getattr(chunk, "content", "")), + "position": int(getattr(chunk, "position", 0)), + "score": float(getattr(chunk, "score", 0.0)), + } + for chunk in (record.child_chunks or []) + ], "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": 
segment.position, @@ -572,7 +580,7 @@ class KnowledgeRetrievalNode(BaseNode): def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): - if value is None: + if value is None and condition not in ("empty", "not empty"): return key = f"{metadata_name}_{sequence}" diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 4bb62d35a2..e6f8abeba0 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -13,7 +13,7 @@ class ModelConfig(BaseModel): provider: str name: str mode: LLMMode - completion_params: dict[str, Any] = {} + completion_params: dict[str, Any] = Field(default_factory=dict) class ContextConfig(BaseModel): diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 655de9362f..9a288c6133 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -313,30 +313,31 @@ class LoopNode(BaseNode): and event.node_type == NodeType.LOOP_END and not isinstance(event, NodeRunStreamChunkEvent) ): - check_break_result = True + # Check if variables in break conditions exist and process conditions + # Allow loop internal variables to be used in break conditions + available_conditions = [] + for condition in break_conditions: + variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector) + if variable: + available_conditions.append(condition) + + # Process conditions if at least one variable is available + if available_conditions: + input_conditions, group_result, check_break_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=available_conditions, + operator=logical_operator, + ) + if check_break_result: + break + else: + check_break_result = True yield self._handle_event_metadata(event=event, iter_run_index=current_index) break if 
isinstance(event, NodeRunSucceededEvent): yield self._handle_event_metadata(event=event, iter_run_index=current_index) - # Check if all variables in break conditions exist - exists_variable = False - for condition in break_conditions: - if not self.graph_runtime_state.variable_pool.get(condition.variable_selector): - exists_variable = False - break - else: - exists_variable = True - if exists_variable: - input_conditions, group_result, check_break_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if check_break_result: - break - elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # Loop run failed diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 45c5e0a62c..49c4c142e1 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,3 +1,4 @@ +import contextlib import json import logging import uuid @@ -666,10 +667,8 @@ class ParameterExtractorNode(BaseNode): if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: - try: + with contextlib.suppress(Exception): return cast(dict, json.loads(json_str)) - except Exception: - pass logger.info("extra error: %s", result) return None @@ -686,10 +685,9 @@ class ParameterExtractorNode(BaseNode): if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: - try: + with contextlib.suppress(Exception): return cast(dict, json.loads(json_str)) - except Exception: - pass + logger.info("extra error: %s", result) return None diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index c607161e2a..1b0321f42e 100644 --- 
a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,3 +1,4 @@ +import contextlib import logging import time @@ -38,12 +39,11 @@ def handle(sender, **kwargs): db.session.add(document) db.session.commit() - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logging.info(click.style(str(ex), fg="yellow")) - except Exception: - pass + with contextlib.suppress(Exception): + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a8f025a750..3fd9633e79 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -1,4 +1,5 @@ import atexit +import contextlib import logging import os import platform @@ -106,7 +107,7 @@ def init_app(app: DifyApp): """Custom logging handler that creates spans for logging.exception() calls""" def emit(self, record: logging.LogRecord): - try: + with contextlib.suppress(Exception): if record.exc_info: tracer = get_tracer_provider().get_tracer("dify.exception.logging") with tracer.start_as_current_span( @@ -126,9 +127,6 @@ def init_app(app: DifyApp): if record.exc_info[0]: span.set_attribute("exception.type", record.exc_info[0].__name__) - except Exception: - pass - from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter diff --git a/api/extensions/ext_redis.py 
b/api/extensions/ext_redis.py index f5f544679f..1b22886fc1 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Optional, Union import redis from redis import RedisError @@ -246,7 +246,7 @@ def init_app(app: DifyApp): app.extensions["redis"] = redis_client -def redis_fallback(default_return: Any = None): +def redis_fallback(default_return: Optional[Any] = None): """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. diff --git a/api/models/workflow.py b/api/models/workflow.py index ed23cb9c16..7986c36a74 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 import sqlalchemy as sa -from flask_login import current_user from sqlalchemy import DateTime, orm from core.file.constants import maybe_file_object @@ -18,7 +17,6 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now -from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -363,8 +361,8 @@ class Workflow(Base): if self._environment_variables is None: self._environment_variables = "{}" - # Get tenant_id from current_user (Account or EndUser) - tenant_id = extract_tenant_id(current_user) + # Use workflow.tenant_id to avoid relying on request user in background threads + tenant_id = self.tenant_id if not tenant_id: return [] @@ -394,8 +392,8 @@ class Workflow(Base): self._environment_variables = "{}" return - # Get tenant_id from current_user (Account or EndUser) - tenant_id = 
extract_tenant_id(current_user) + # Use workflow.tenant_id to avoid relying on request user in background threads + tenant_id = self.tenant_id if not tenant_id: self._environment_variables = "{}" diff --git a/api/pyproject.toml b/api/pyproject.toml index ce642aa9c8..cf5ad8e7d2 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "cachetools~=5.3.0", "celery~=5.5.2", "chardet~=5.1.0", - "flask~=3.1.0", + "flask~=3.1.2", "flask-compress~=1.17", "flask-cors~=6.0.0", "flask-login~=0.6.3", diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 4f3dd3c762..712ef4c601 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,3 +1,4 @@ +import contextlib from collections.abc import Callable, Sequence from typing import Any, Optional, Union @@ -142,13 +143,11 @@ class ConversationService: raise MessageNotExistsError() # generate conversation name - try: + with contextlib.suppress(Exception): name = LLMGenerator.generate_conversation_name( app_model.tenant_id, message.query, conversation.id, app_model.id ) conversation.name = name - except Exception: - pass db.session.commit() diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 59d5b50e23..f245dd7527 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager @@ -9,7 +10,7 @@ logger = logging.getLogger(__name__) class ToolCommonService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None): + def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None): """ list tool providers diff --git a/api/services/workflow/workflow_converter.py 
b/api/services/workflow/workflow_converter.py index afcf1f7621..00b02f8091 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -402,7 +402,7 @@ class WorkflowConverter: ) role_prefix = None - prompts: Any = None + prompts: Optional[Any] = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 4af35a8bef..be5b4de5a2 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,5 +1,6 @@ import os from collections import UserDict +from typing import Optional from unittest.mock import MagicMock import pytest @@ -21,7 +22,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: HTTPAdapter = None, + adapter: Optional[HTTPAdapter] = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index ae5f9761b4..02f658aad6 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -23,7 +23,7 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: HTTPAdapter = None, + adapter: Optional[HTTPAdapter] = None, pool_size: int = 2, proxies: Optional[dict] = None, password: Optional[str] = None, @@ -72,11 +72,11 @@ class MockTcvectordbClass: shard: int, replicas: int, description: Optional[str] = None, - index: Index = None, - embedding: Embedding = None, + index: Optional[Index] = None, + embedding: Optional[Embedding] = None, timeout: Optional[float] = None, ttl_config: Optional[dict] = None, - filter_index_config: FilterIndexConfig = None, + filter_index_config: Optional[FilterIndexConfig] = None, indexes: 
Optional[list[IndexField]] = None, ) -> RPCCollection: return RPCCollection( @@ -113,7 +113,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, vectors: list[list[float]], - filter: Filter = None, + filter: Optional[Filter] = None, params=None, retrieve_vector: bool = False, limit: int = 10, @@ -128,7 +128,7 @@ class MockTcvectordbClass: collection_name: str, ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, - filter: Union[Filter, str] = None, + filter: Optional[Union[Filter, str]] = None, rerank: Optional[Rerank] = None, retrieve_vector: Optional[bool] = None, output_fields: Optional[list[str]] = None, @@ -158,7 +158,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, document_ids: Optional[list[str]] = None, - filter: Filter = None, + filter: Optional[Filter] = None, timeout: Optional[float] = None, ): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py index 8b57132772..21de8be6e3 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -1,3 +1,4 @@ +import contextlib import os import pytest @@ -44,10 +45,8 @@ class TestClickzettaVector(AbstractVectorTest): yield vector # Cleanup: delete the test collection - try: + with contextlib.suppress(Exception): vector.delete() - except Exception: - pass def test_clickzetta_vector_basic_operations(self, vector_store): """Test basic CRUD operations on Clickzetta vector store.""" diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py new file mode 100644 index 0000000000..2d5cdf426d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -0,0 
+1,1192 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset +from models.model import App, Tag, TagBinding +from services.tag_service import TagService + + +class TestTagService: + """Integration tests for TagService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.tag_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_current_user.current_tenant_id = "test-tenant-id" + mock_current_user.id = "test-user-id" + + yield { + "current_user": mock_current_user, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + # Update mock to use real tenant ID + 
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + return account, tenant + + def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + """ + Helper method to create a test dataset for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the dataset + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + name=fake.company(), + description=fake.text(max_nb_chars=100), + provider="vendor", + permission="only_me", + data_source_type="upload", + indexing_technique="high_quality", + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + """ + Helper method to create a test app for testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the app + + Returns: + App: Created app instance + """ + fake = Faker() + + app = App( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + enable_site=False, + enable_api=False, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + + return app + + def _create_test_tags( + self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + ): + """ + Helper method to create test tags for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the tags + tag_type: Type of tags to create + count: Number of tags to create + + Returns: + list: List of created tag instances + """ + fake = Faker() + tags = [] + + for i in range(count): + tag = Tag( + name=f"tag_{tag_type}_{i}_{fake.word()}", + type=tag_type, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + tags.append(tag) + + from extensions.ext_database import db + + for tag in tags: + db.session.add(tag) + db.session.commit() + + return tags + + def _create_test_tag_bindings( + self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + ): + """ + Helper method to create test tag bindings for testing. 
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tags: List of tags to bind + target_id: Target ID to bind tags to + tenant_id: Tenant ID for the bindings + + Returns: + list: List of created tag binding instances + """ + tag_bindings = [] + + for tag in tags: + tag_binding = TagBinding( + tag_id=tag.id, + target_id=target_id, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + tag_bindings.append(tag_binding) + + from extensions.ext_database import db + + for tag_binding in tag_bindings: + db.session.add(tag_binding) + db.session.commit() + + return tag_bindings + + def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags with binding count. + + This test verifies: + - Proper tag retrieval with binding count + - Correct filtering by tag type and tenant + - Proper ordering by creation date + - Binding count calculation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3 + ) + + # Create dataset and bind tags + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags[:2], dataset.id, tenant.id + ) + + # Act: Execute the method under test + result = TagService.get_tags("knowledge", tenant.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 3 + + # Verify tag data structure + for tag_result in result: + assert hasattr(tag_result, "id") + assert hasattr(tag_result, "type") + 
assert hasattr(tag_result, "name") + assert hasattr(tag_result, "binding_count") + assert tag_result.type == "knowledge" + + # Verify binding count + tag_with_bindings = next((t for t in result if t.binding_count > 0), None) + assert tag_with_bindings is not None + assert tag_with_bindings.binding_count >= 1 + + # Verify ordering (newest first) - note: created_at is not in SELECT but used in ORDER BY + # The ordering is handled by the database, we just verify the results are returned + assert len(result) == 3 + + def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval with keyword filtering. + + This test verifies: + - Proper keyword filtering functionality + - Case-insensitive search + - Partial match functionality + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags with specific names + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3 + ) + + # Update tag names to make them searchable + from extensions.ext_database import db + + tags[0].name = "python_development" + tags[1].name = "machine_learning" + tags[2].name = "web_development" + db.session.commit() + + # Act: Execute the method under test with keyword filter + result = TagService.get_tags("app", tenant.id, keyword="development") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 2 # Should find python_development and web_development + + # Verify filtered results contain the keyword + for tag_result in result: + assert "development" in tag_result.name.lower() + + # Verify no results for non-matching keyword + result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") + assert len(result_no_match) == 0 + + def test_get_tags_empty_result(self, 
db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval when no tags exist. + + This test verifies: + - Proper handling of empty tag sets + - Correct return value for no results + """ + # Arrange: Create test data without tags + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test + result = TagService.get_tags("knowledge", tenant.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of target IDs by tag IDs. + + This test verifies: + - Proper target ID retrieval for valid tag IDs + - Correct filtering by tag type and tenant + - Proper handling of tag bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3 + ) + + # Create multiple datasets and bind tags + datasets = [] + for i in range(2): + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, tenant.id + ) + datasets.append(dataset) + # Bind first two tags to first dataset, last tag to second dataset + tags_to_bind = tags[:2] if i == 0 else tags[2:] + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags_to_bind, dataset.id, tenant.id + ) + + # Act: Execute the method under test + tag_ids = [tag.id for tag in tags] + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids) + + # Assert: Verify the expected outcomes + assert result 
is not None + assert len(result) == 3 # Should find 3 target IDs (2 from first dataset, 1 from second) + + # Verify all dataset IDs are returned + dataset_ids = [dataset.id for dataset in datasets] + for target_id in result: + assert target_id in dataset_ids + + # Verify the first dataset appears twice (for the first two tags) + first_dataset_count = result.count(datasets[0].id) + assert first_dataset_count == 2 + + # Verify the second dataset appears once (for the last tag) + second_dataset_count = result.count(datasets[1].id) + assert second_dataset_count == 1 + + def test_get_target_ids_by_tag_ids_empty_tag_ids( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target ID retrieval with empty tag IDs list. + + This test verifies: + - Proper handling of empty tag IDs + - Correct return value for empty input + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with empty tag IDs + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, []) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_target_ids_by_tag_ids_no_matching_tags( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target ID retrieval when no tags match the criteria. 
+ + This test verifies: + - Proper handling of non-existent tag IDs + - Correct return value for no matches + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag IDs + import uuid + + non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())] + + # Act: Execute the method under test + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags by tag name. + + This test verifies: + - Proper tag retrieval by name + - Correct filtering by tag type and tenant + - Proper return value structure + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags with specific names + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 2 + ) + + # Update tag names to make them searchable + from extensions.ext_database import db + + tags[0].name = "python_tag" + tags[1].name = "ml_tag" + db.session.commit() + + # Act: Execute the method under test + result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 1 + assert result[0].name == "python_tag" + assert result[0].type == "app" + assert result[0].tenant_id == tenant.id + + def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by name when no matches exist. 
+ + This test verifies: + - Proper handling of non-existent tag names + - Correct return value for no matches + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with non-existent tag name + result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by name with empty parameters. + + This test verifies: + - Proper handling of empty tag type + - Proper handling of empty tag name + - Correct return value for invalid input + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with empty parameters + result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag") + result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "") + + # Assert: Verify the expected outcomes + assert result_empty_type is not None + assert len(result_empty_type) == 0 + assert result_empty_name is not None + assert len(result_empty_name) == 0 + + def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags by target ID. 
+ + This test verifies: + - Proper tag retrieval for a specific target + - Correct filtering by tag type and tenant + - Proper join with tag bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3 + ) + + # Create app and bind tags + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags, app.id, tenant.id + ) + + # Act: Execute the method under test + result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 3 + + # Verify all tags are returned + for tag in result: + assert tag.type == "app" + assert tag.tenant_id == tenant.id + assert tag.id in [t.id for t in tags] + + def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by target ID when no tags are bound. 
+ + This test verifies: + - Proper handling of targets with no tag bindings + - Correct return value for no results + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create app without binding any tags + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag creation. + + This test verifies: + - Proper tag creation with all required fields + - Correct database state after creation + - Proper UUID generation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + tag_args = {"name": "test_tag_name", "type": "knowledge"} + + # Act: Execute the method under test + result = TagService.save_tags(tag_args) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "test_tag_name" + assert result.type == "knowledge" + assert result.tenant_id == tenant.id + assert result.created_by == account.id + assert result.id is not None + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + + # Verify tag was actually saved to database + saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + assert saved_tag is not None + assert saved_tag.name == "test_tag_name" + + def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + 
Test tag creation with duplicate name. + + This test verifies: + - Proper error handling for duplicate tag names + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first tag + tag_args = {"name": "duplicate_tag", "type": "app"} + TagService.save_tags(tag_args) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + TagService.save_tags(tag_args) + assert "Tag name already exists" in str(exc_info.value) + + def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag update. + + This test verifies: + - Proper tag update with new name + - Correct database state after update + - Proper error handling for non-existent tags + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create a tag to update + tag_args = {"name": "original_name", "type": "knowledge"} + tag = TagService.save_tags(tag_args) + + # Update args + update_args = {"name": "updated_name", "type": "knowledge"} + + # Act: Execute the method under test + result = TagService.update_tags(update_args, tag.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "updated_name" + assert result.type == "knowledge" + assert result.id == tag.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.name == "updated_name" + + # Verify tag was actually updated in database + updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + assert updated_tag is not None + assert updated_tag.name == "updated_name" + + def test_update_tags_not_found_error(self, db_session_with_containers, 
mock_external_service_dependencies): + """ + Test tag update for non-existent tag. + + This test verifies: + - Proper error handling for non-existent tags + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + update_args = {"name": "updated_name", "type": "knowledge"} + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.update_tags(update_args, non_existent_tag_id) + assert "Tag not found" in str(exc_info.value) + + def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag update with duplicate name. + + This test verifies: + - Proper error handling for duplicate tag names during update + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create two tags + tag1_args = {"name": "first_tag", "type": "app"} + tag1 = TagService.save_tags(tag1_args) + + tag2_args = {"name": "second_tag", "type": "app"} + tag2 = TagService.save_tags(tag2_args) + + # Try to update second tag with first tag's name + update_args = {"name": "first_tag", "type": "app"} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + TagService.update_tags(update_args, tag2.id) + assert "Tag name already exists" in str(exc_info.value) + + def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tag binding count. 
+ + This test verifies: + - Proper binding count calculation + - Correct handling of tags with no bindings + - Proper database query execution + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) + + # Create dataset and bind first tag + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tags[0]], dataset.id, tenant.id + ) + + # Act: Execute the method under test + result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id) + result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id) + + # Assert: Verify the expected outcomes + assert result_tag_with_bindings == 1 + assert result_tag_without_bindings == 0 + + def test_get_tag_binding_count_non_existent_tag( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test binding count retrieval for non-existent tag. + + This test verifies: + - Proper handling of non-existent tag IDs + - Correct return value for non-existent tags + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + # Act: Execute the method under test + result = TagService.get_tag_binding_count(non_existent_tag_id) + + # Assert: Verify the expected outcomes + assert result == 0 + + def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag deletion. 
+ + This test verifies: + - Proper tag deletion from database + - Proper cleanup of associated tag bindings + - Correct database state after deletion + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag with bindings + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + + # Create app and bind tag + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tag], app.id, tenant.id + ) + + # Verify tag and binding exist before deletion + from extensions.ext_database import db + + tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + assert tag_before is not None + + binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + assert binding_before is not None + + # Act: Execute the method under test + TagService.delete_tag(tag.id) + + # Assert: Verify the expected outcomes + # Verify tag was deleted + tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + assert tag_after is None + + # Verify tag binding was deleted + binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + assert binding_after is None + + def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag deletion for non-existent tag. 
+ + This test verifies: + - Proper error handling for non-existent tags + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.delete_tag(non_existent_tag_id) + assert "Tag not found" in str(exc_info.value) + + def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag binding creation. + + This test verifies: + - Proper tag binding creation + - Correct handling of duplicate bindings + - Proper database state after creation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) + + # Create dataset + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} + TagService.save_tag_binding(binding_args) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify tag bindings were created + for tag in tags: + binding = ( + db.session.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() + ) + assert binding is not None + assert binding.tenant_id == tenant.id + assert binding.created_by == account.id + + def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, 
mock_external_service_dependencies): + """ + Test tag binding creation with duplicate bindings. + + This test verifies: + - Proper handling of duplicate tag bindings + - No errors when trying to create existing bindings + - Correct database state after operation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + + # Create app + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Create first binding + binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} + TagService.save_tag_binding(binding_args) + + # Act: Try to create duplicate binding + TagService.save_tag_binding(binding_args) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify only one binding exists + bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + assert len(bindings) == 1 + + def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag binding creation with invalid target type. 
+ + This test verifies: + - Proper error handling for invalid target types + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 + )[0] + + # Create non-existent target ID + import uuid + + non_existent_target_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} + + with pytest.raises(NotFound) as exc_info: + TagService.save_tag_binding(binding_args) + assert "Invalid binding type" in str(exc_info.value) + + def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag binding deletion. + + This test verifies: + - Proper tag binding deletion from database + - Correct database state after deletion + - Proper error handling for non-existent bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 + )[0] + + # Create dataset and bind tag + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id + ) + + # Verify binding exists before deletion + from extensions.ext_database import db + + binding_before = ( + db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + ) + assert binding_before is 
not None + + # Act: Execute the method under test + delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} + TagService.delete_tag_binding(delete_args) + + # Assert: Verify the expected outcomes + # Verify tag binding was deleted + binding_after = ( + db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + ) + assert binding_after is None + + def test_delete_tag_binding_non_existent_binding( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tag binding deletion for non-existent binding. + + This test verifies: + - Proper handling of non-existent tag bindings + - No errors when trying to delete non-existent bindings + - Correct database state after operation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag and dataset without binding + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Try to delete non-existent binding + delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} + TagService.delete_tag_binding(delete_args) + + # Assert: Verify the expected outcomes + # No error should be raised, and database state should remain unchanged + from extensions.ext_database import db + + bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + assert len(bindings) == 0 + + def test_check_target_exists_knowledge_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful target existence check for knowledge type. 
+ + This test verifies: + - Proper validation of knowledge dataset existence + - Correct error handling for non-existent datasets + - Proper tenant filtering + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create dataset + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + TagService.check_target_exists("knowledge", dataset.id) + + # Assert: Verify the expected outcomes + # No exception should be raised for existing dataset + + def test_check_target_exists_knowledge_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target existence check for non-existent knowledge dataset. + + This test verifies: + - Proper error handling for non-existent knowledge datasets + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent dataset ID + import uuid + + non_existent_dataset_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("knowledge", non_existent_dataset_id) + assert "Dataset not found" in str(exc_info.value) + + def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful target existence check for app type. 
+ + This test verifies: + - Proper validation of app existence + - Correct error handling for non-existent apps + - Proper tenant filtering + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create app + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + TagService.check_target_exists("app", app.id) + + # Assert: Verify the expected outcomes + # No exception should be raised for existing app + + def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test target existence check for non-existent app. + + This test verifies: + - Proper error handling for non-existent apps + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent app ID + import uuid + + non_existent_app_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("app", non_existent_app_id) + assert "App not found" in str(exc_info.value) + + def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test target existence check for invalid type. 
+ + This test verifies: + - Proper error handling for invalid target types + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent target ID + import uuid + + non_existent_target_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("invalid_type", non_existent_target_id) + assert "Invalid binding type" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 880a0d4940..aadd366762 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -1,3 +1,4 @@ +import contextlib import json import queue import threading @@ -124,13 +125,10 @@ def test_sse_client_connection_validation(): mock_event_source.iter_sse.return_value = [endpoint_event] # Test connection - try: + with contextlib.suppress(Exception): with sse_client(test_url) as (read_queue, write_queue): assert read_queue is not None assert write_queue is not None - except Exception as e: - # Connection might fail due to mocking, but we're testing the validation logic - pass def test_sse_client_error_handling(): @@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration(): mock_event_source.iter_sse.return_value = [] mock_sse_connect.return_value.__enter__.return_value = mock_event_source - try: + with contextlib.suppress(Exception): with sse_client( test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout ) as (read_queue, write_queue): @@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration(): assert call_args is not None timeout_arg = call_args[1]["timeout"] assert timeout_arg.read == custom_sse_timeout - except Exception: - # Connection might 
fail due to mocking, but we tested the configuration - pass def test_sse_transport_endpoint_validation(): @@ -251,12 +246,10 @@ def test_sse_client_queue_cleanup(): # Mock connection that raises an exception mock_sse_connect.side_effect = Exception("Connection failed") - try: + with contextlib.suppress(Exception): with sse_client(test_url) as (rq, wq): read_queue = rq write_queue = wq - except Exception: - pass # Expected to fail # Queues should be cleaned up even on exception # Note: In real implementation, cleanup should put None to signal shutdown @@ -283,11 +276,9 @@ def test_sse_client_headers_propagation(): mock_event_source.iter_sse.return_value = [] mock_sse_connect.return_value.__enter__.return_value = mock_event_source - try: + with contextlib.suppress(Exception): with sse_client(test_url, headers=custom_headers): pass - except Exception: - pass # Expected due to mocking # Verify headers were passed to client factory mock_client_factory.assert_called_with(headers=custom_headers) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index f6d22690d1..8abed0a3f9 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -164,7 +164,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 - assert prompt_messages[3].content[1].data == files[0].remote_url + assert prompt_messages[3].content[0].data == files[0].remote_url @pytest.fixture diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py new file mode 100644 index 0000000000..6425ab0b8d --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -0,0 +1,181 @@ +import copy +from unittest.mock import 
patch + +import pytest + +from core.entities.provider_entities import BasicProviderConfig +from core.tools.utils.encryption import ProviderConfigEncrypter + + +# --------------------------- +# A no-op cache +# --------------------------- +class NoopCache: + """Simple cache stub: always returns None, does nothing for set/delete.""" + + def get(self): + return None + + def set(self, config): + pass + + def delete(self): + pass + + +@pytest.fixture +def secret_field() -> BasicProviderConfig: + """A SECRET_INPUT field named 'password'.""" + return BasicProviderConfig( + name="password", + type=BasicProviderConfig.Type.SECRET_INPUT, + ) + + +@pytest.fixture +def normal_field() -> BasicProviderConfig: + """A TEXT_INPUT field named 'username'.""" + return BasicProviderConfig( + name="username", + type=BasicProviderConfig.Type.TEXT_INPUT, + ) + + +@pytest.fixture +def encrypter_obj(secret_field, normal_field): + """ + Build ProviderConfigEncrypter with: + - tenant_id = tenant123 + - one secret field (password) and one normal field (username) + - NoopCache as cache + """ + return ProviderConfigEncrypter( + tenant_id="tenant123", + config=[secret_field, normal_field], + provider_config_cache=NoopCache(), + ) + + +# ============================================================ +# ProviderConfigEncrypter.encrypt() +# ============================================================ + + +def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj): + """ + Secret field should be encrypted, non-secret field unchanged. + Verify encrypt_token called only for secret field. + Also check deep copy (input not modified). 
+    """
+    data_in = {"username": "alice", "password": "plain_pwd"}
+    data_copy = copy.deepcopy(data_in)
+
+    with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
+        out = encrypter_obj.encrypt(data_in)
+
+    assert out["username"] == "alice"
+    assert out["password"] == "CIPHERTEXT"
+    mock_encrypt.assert_called_once_with("tenant123", "plain_pwd")
+    assert data_in == data_copy  # deep copy semantics
+
+
+def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
+    """If secret field missing in input, no error and no encryption called."""
+    with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt:
+        out = encrypter_obj.encrypt({"username": "alice"})
+        assert out["username"] == "alice"
+        mock_encrypt.assert_not_called()
+
+
+# ============================================================
+# ProviderConfigEncrypter.mask_tool_credentials()
+# ============================================================
+
+
+@pytest.mark.parametrize(
+    ("raw", "prefix", "suffix"),
+    [
+        ("longsecret", "lo", "et"),
+        ("abcdefg", "ab", "fg"),
+        ("1234567", "12", "67"),
+    ],
+)
+def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix):
+    """
+    For length > 6: keep first 2 and last 2, mask middle with '*'.
+    """
+    data_in = {"username": "alice", "password": raw}
+    data_copy = copy.deepcopy(data_in)
+
+    out = encrypter_obj.mask_tool_credentials(data_in)
+    masked = out["password"]
+
+    assert masked.startswith(prefix)
+    assert masked.endswith(suffix)
+    assert "*" in masked
+    assert len(masked) == len(raw)
+    assert data_in == data_copy  # deep copy semantics
+
+
+@pytest.mark.parametrize("raw", ["", "1", "12", "123", "123456"])
+def test_mask_tool_credentials_short_secret(encrypter_obj, raw):
+    """
+    For length <= 6: fully mask with '*' of same length.
+    """
+    out = encrypter_obj.mask_tool_credentials({"password": raw})
+    assert out["password"] == ("*" * len(raw))
+
+
+def test_mask_tool_credentials_missing_key_noop(encrypter_obj):
+    """If secret key missing, leave other fields unchanged."""
+    data_in = {"username": "alice"}
+    data_copy = copy.deepcopy(data_in)
+
+    out = encrypter_obj.mask_tool_credentials(data_in)
+    assert out["username"] == "alice"
+    assert data_in == data_copy
+
+
+# ============================================================
+# ProviderConfigEncrypter.decrypt()
+# ============================================================
+
+
+def test_decrypt_normal_flow(encrypter_obj):
+    """
+    Normal decrypt flow:
+    - decrypt_token called for secret field
+    - secret replaced with decrypted value
+    - non-secret unchanged
+    """
+    data_in = {"username": "alice", "password": "ENC"}
+    data_copy = copy.deepcopy(data_in)
+
+    with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
+        out = encrypter_obj.decrypt(data_in)
+
+    assert out["username"] == "alice"
+    assert out["password"] == "PLAIN"
+    mock_decrypt.assert_called_once_with("tenant123", "ENC")
+    assert data_in == data_copy  # deep copy semantics
+
+
+@pytest.mark.parametrize("empty_val", ["", None])
+def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
+    """Skip decrypt if value is empty or None, keep original."""
+    with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt:
+        out = encrypter_obj.decrypt({"password": empty_val})
+
+    mock_decrypt.assert_not_called()
+    assert out["password"] == empty_val
+
+
+def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
+    """
+    If decrypt_token raises, exception should be swallowed,
+    and original value preserved.
+    """
+    with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
+        out = encrypter_obj.decrypt({"password": "ENC_ERR"})
+
+    assert out["password"] == "ENC_ERR"
diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
index c17308baad..20f753786d 100644
--- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
+++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
@@ -1,6 +1,293 @@
-from core.tools.utils.web_reader_tool import get_image_upload_file_ids
+import pytest
+
+from core.tools.utils.web_reader_tool import (
+    extract_using_readabilipy,
+    get_image_upload_file_ids,
+    get_url,
+    page_result,
+)
 
 
+class FakeResponse:
+    """Minimal fake response object for ssrf_proxy / cloudscraper."""
+
+    def __init__(self, *, status_code=200, headers=None, content=b"", text=""):
+        self.status_code = status_code
+        self.headers = headers or {}
+        self.content = content
+        self.text = text if text else content.decode("utf-8", errors="ignore")
+
+
+# ---------------------------
+# Tests: page_result
+# ---------------------------
+@pytest.mark.parametrize(
+    ("text", "cursor", "maxlen", "expected"),
+    [
+        ("abcdef", 0, 3, "abc"),
+        ("abcdef", 2, 10, "cdef"),  # maxlen beyond end
+        ("abcdef", 6, 5, ""),  # cursor at end
+        ("abcdef", 7, 5, ""),  # cursor beyond end
+        ("", 0, 5, ""),  # empty text
+    ],
+)
+def test_page_result(text, cursor, maxlen, expected):
+    assert page_result(text, cursor, maxlen) == expected
+
+
+# ---------------------------
+# Tests: get_url
+# ---------------------------
+@pytest.fixture
+def stub_support_types(monkeypatch):
+    """Stub supported content types list."""
+    import core.tools.utils.web_reader_tool as mod
+
+    # e.g. binary types supported by ExtractProcessor
+    monkeypatch.setattr(mod.extract_processor, "SUPPORT_URL_CONTENT_TYPES", ["application/pdf", "text/plain"])
+    return mod
+
+
+def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
+    # HEAD 200 but content-type not supported and not text/html
+    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+        return FakeResponse(
+            status_code=200,
+            headers={"Content-Type": "image/png"},  # not supported
+        )
+
+    monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
+
+    result = get_url("https://x.test/file.png")
+    assert result == "Unsupported content-type [image/png] of URL."
+
+
+def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types):
+    """
+    When content-type is in SUPPORT_URL_CONTENT_TYPES,
+    should call ExtractProcessor.load_from_url and return its text.
+    """
+    calls = {"load": 0}
+
+    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+        return FakeResponse(
+            status_code=200,
+            headers={"Content-Type": "application/pdf"},
+        )
+
+    def fake_load_from_url(url, return_text=False):
+        calls["load"] += 1
+        assert return_text is True
+        return "PDF extracted text"
+
+    monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
+    monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
+
+    result = get_url("https://x.test/doc.pdf")
+    assert calls["load"] == 1
+    assert result == "PDF extracted text"
+
+
+def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types):
+    """200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""
+
+    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
+        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})
+
+    def fake_get(url, headers=None, follow_redirects=True, timeout=None):
+        html = b"