diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7f7787b92a..23a877b7e3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict): controller: ApiToolProviderController +class EmojiIconDict(TypedDict): + background: str + content: str + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -916,7 +921,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -933,7 +938,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -950,7 +955,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -970,7 +975,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | Mapping[str, str]: + ) -> str | EmojiIconDict | dict[str, str]: """ get the tool icon diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 3dbbbe6563..c2b520fa99 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,5 +1,4 @@ import threading -from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 057ec41f65..2969fafe89 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import NotRequired, TypedDict, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -16,7 +16,19 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model: dict[str, Any] = { + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: dict[str, str] + reranking_mode: NotRequired[str] + weights: NotRequired[dict[str, object] | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -125,7 +137,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if metadata_condition and not document_ids_filter: return "" # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fc2b41d960..f7484b93fb 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -20,10 +21,18 @@ class InterfaceDict(TypedDict): operation: dict[str, Any] +class OpenAPISpecDict(TypedDict): + openapi: str + info: dict[str, str] + servers: list[dict[str, Any]] + paths: dict[str, Any] + components: dict[str, Any] + + class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_to_tool_bundle( - openapi: dict, extra_info: dict | None = None, warning: dict | None = None + openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_swagger_to_openapi( swagger: dict, extra_info: dict | None = None, warning: dict | None = None - ) -> dict[str, Any]: + ) -> OpenAPISpecDict: warning = warning or {} """ parse swagger to openapi @@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - converted_openapi: dict[str, Any] = { + converted_openapi: OpenAPISpecDict = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"),