diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index cdbce3ce14..37ed5b8385 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -10,7 +10,7 @@ class RerankingModelConfig(BaseModel):
     Reranking Model Config.
     """
     provider: str
-    mode: str
+    model: str
 
 
 class MultipleRetrievalConfig(BaseModel):
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 87ba4239f8..0534695adb 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -5,11 +5,12 @@ from flask import Flask, current_app
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelType, ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
@@ -192,6 +193,25 @@ class KnowledgeRetrievalNode(BaseNode):
             tools.append(message_tool)
         # fetch model config
         model_instance, model_config = self._fetch_model_config(node_data)
+        # check model is support tool calling
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        # get model schema
+        model_schema = model_type_instance.get_model_schema(
+            model=model_config.model,
+            credentials=model_config.credentials
+        )
+
+        if not model_schema:
+            return None
+        planning_strategy = PlanningStrategy.REACT_ROUTER
+        features = model_schema.features
+        if features:
+            if ModelFeature.TOOL_CALL in features \
+                    or ModelFeature.MULTI_TOOL_CALL in features:
+                planning_strategy = PlanningStrategy.ROUTER
+
+
         prompt_messages = [
             SystemPromptMessage(content='You are a helpful AI assistant.'),
             UserPromptMessage(content=query)
@@ -328,7 +348,7 @@ class KnowledgeRetrievalNode(BaseNode):
                 tenant_id=self.tenant_id,
                 provider=node_data.multiple_retrieval_config.reranking_model.provider,
                 model_type=ModelType.RERANK,
-                model=node_data.multiple_retrieval_config.reranking_model.name
+                model=node_data.multiple_retrieval_config.reranking_model.model
            )
 
            rerank_runner = RerankRunner(rerank_model_instance)
@@ -374,3 +394,4 @@ class KnowledgeRetrievalNode(BaseNode):
            )
 
            all_documents.extend(documents)
+
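
For context, the check added in the second hunk reduces to one decision: default to the ReAct-style router and switch to the native router only when the model schema advertises tool calling. The sketch below is a standalone illustration of that logic, not part of the patch; the enum member names mirror those used in the diff, while the string values and the helper function are assumptions made for the example.

from enum import Enum
from typing import Optional


class PlanningStrategy(Enum):
    # member names taken from the diff; values are illustrative assumptions
    ROUTER = "router"
    REACT_ROUTER = "react_router"


class ModelFeature(Enum):
    # member names taken from the diff; values are illustrative assumptions
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"


def choose_planning_strategy(features: Optional[list[ModelFeature]]) -> PlanningStrategy:
    # Hypothetical helper mirroring the added check: prefer the native router
    # only when the model declares (multi) tool-call support.
    if features and (ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features):
        return PlanningStrategy.ROUTER
    return PlanningStrategy.REACT_ROUTER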