knowledge fix

This commit is contained in:
jyong 2024-03-18 20:35:10 +08:00
parent e66c55ba9e
commit a4f367b8ff
2 changed files with 24 additions and 3 deletions

View File

@ -10,7 +10,7 @@ class RerankingModelConfig(BaseModel):
Reranking Model Config.
"""
provider: str
mode: str
model: str
class MultipleRetrievalConfig(BaseModel):

View File

@ -5,11 +5,12 @@ from flask import Flask, current_app
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelType, ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
@ -192,6 +193,25 @@ class KnowledgeRetrievalNode(BaseNode):
tools.append(message_tool)
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
prompt_messages = [
SystemPromptMessage(content='You are a helpful AI assistant.'),
UserPromptMessage(content=query)
@ -328,7 +348,7 @@ class KnowledgeRetrievalNode(BaseNode):
tenant_id=self.tenant_id,
provider=node_data.multiple_retrieval_config.reranking_model.provider,
model_type=ModelType.RERANK,
model=node_data.multiple_retrieval_config.reranking_model.name
model=node_data.multiple_retrieval_config.reranking_model.model
)
rerank_runner = RerankRunner(rerank_model_instance)
@ -374,3 +394,4 @@ class KnowledgeRetrievalNode(BaseNode):
)
all_documents.extend(documents)