mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
r2
This commit is contained in:
parent
7f59ffe7af
commit
797d044714
@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args.get("knowledge_base_setting"):
|
|
||||||
raise ValueError("Missing knowledge base setting.")
|
|
||||||
|
|
||||||
knowledge_base_setting_data = args.get("knowledge_base_setting")
|
|
||||||
if not knowledge_base_setting_data:
|
|
||||||
raise ValueError("Missing knowledge base setting.")
|
|
||||||
|
|
||||||
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
pipeline = session.merge(pipeline)
|
pipeline = session.merge(pipeline)
|
||||||
@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
session=session,
|
session=session,
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
knowledge_base_setting=knowledge_base_setting,
|
|
||||||
)
|
)
|
||||||
pipeline.is_published = True
|
pipeline.is_published = True
|
||||||
pipeline.workflow_id = workflow.id
|
pipeline.workflow_id = workflow.id
|
||||||
|
|||||||
@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transformer(json_response: dict[str, Any]) -> dict:
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
for provider in json_response.get("data", []):
|
if json_response.get("data"):
|
||||||
declaration = provider.get("declaration", {}) or {}
|
for provider in json_response.get("data", []):
|
||||||
provider_name = declaration.get("identity", {}).get("name")
|
declaration = provider.get("declaration", {}) or {}
|
||||||
for datasource in declaration.get("datasources", []):
|
provider_name = declaration.get("identity", {}).get("name")
|
||||||
datasource["identity"]["provider"] = provider_name
|
for datasource in declaration.get("datasources", []):
|
||||||
|
datasource["identity"]["provider"] = provider_name
|
||||||
|
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import (
|
|||||||
)
|
)
|
||||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
from core.file import File
|
from core.file import File
|
||||||
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.variables.segments import ArrayAnySegment, FileSegment
|
from core.variables.segments import ArrayAnySegment, FileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
case DatasourceProviderType.LOCAL_FILE:
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
|
related_id = datasource_info.get("related_id")
|
||||||
|
if not related_id:
|
||||||
|
raise DatasourceNodeError(
|
||||||
|
"File is not exist"
|
||||||
|
)
|
||||||
|
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
|
||||||
if not upload_file:
|
if not upload_file:
|
||||||
raise ValueError("Invalid upload file Info")
|
raise ValueError("Invalid upload file Info")
|
||||||
|
|
||||||
@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
extension="." + upload_file.extension,
|
extension="." + upload_file.extension,
|
||||||
mime_type=upload_file.mime_type,
|
mime_type=upload_file.mime_type,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
type=datasource_info.get("type", ""),
|
type=FileType.CUSTOM,
|
||||||
transfer_method=datasource_info.get("transfer_method", ""),
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
remote_url=upload_file.source_url,
|
remote_url=upload_file.source_url,
|
||||||
related_id=upload_file.id,
|
related_id=upload_file.id,
|
||||||
size=upload_file.size,
|
size=upload_file.size,
|
||||||
storage_key=upload_file.key,
|
storage_key=upload_file.key,
|
||||||
)
|
)
|
||||||
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
|
variable_pool.add([self.node_id, "file"], [file_info])
|
||||||
for key, value in datasource_info.items():
|
for key, value in datasource_info.items():
|
||||||
# construct new key list
|
# construct new key list
|
||||||
new_key_list = ["file", key]
|
new_key_list = ["file", key]
|
||||||
@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
outputs={
|
outputs={
|
||||||
"file_info": file_info,
|
"file_info": datasource_info,
|
||||||
"datasource_type": datasource_type,
|
"datasource_type": datasource_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||||
return list(variable.value) if variable else []
|
return list(variable.value) if variable else []
|
||||||
|
|
||||||
|
|
||||||
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||||||
)
|
)
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||||
KnowledgeBaseUpdateConfiguration,
|
KnowledgeBaseUpdateConfiguration,
|
||||||
|
KnowledgeConfiguration,
|
||||||
RagPipelineDatasetCreateEntity,
|
RagPipelineDatasetCreateEntity,
|
||||||
)
|
)
|
||||||
from services.errors.account import InvalidActionError, NoPermissionError
|
from services.errors.account import InvalidActionError, NoPermissionError
|
||||||
@ -495,11 +496,11 @@ class DatasetService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def update_rag_pipeline_dataset_settings(session: Session,
|
def update_rag_pipeline_dataset_settings(session: Session,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
knowledge_configuration: KnowledgeConfiguration,
|
||||||
has_published: bool = False):
|
has_published: bool = False):
|
||||||
if not has_published:
|
if not has_published:
|
||||||
dataset.chunk_structure = knowledge_base_setting.chunk_structure
|
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||||
index_method = knowledge_base_setting.index_method
|
index_method = knowledge_configuration.index_method
|
||||||
dataset.indexing_technique = index_method.indexing_technique
|
dataset.indexing_technique = index_method.indexing_technique
|
||||||
if index_method == "high_quality":
|
if index_method == "high_quality":
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
@ -519,26 +520,26 @@ class DatasetService:
|
|||||||
dataset.keyword_number = index_method.economy_setting.keyword_number
|
dataset.keyword_number = index_method.economy_setting.keyword_number
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid index method")
|
raise ValueError("Invalid index method")
|
||||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
|
||||||
session.add(dataset)
|
session.add(dataset)
|
||||||
else:
|
else:
|
||||||
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
|
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
|
||||||
raise ValueError("Chunk structure is not allowed to be updated.")
|
raise ValueError("Chunk structure is not allowed to be updated.")
|
||||||
action = None
|
action = None
|
||||||
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
|
if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
|
||||||
# if update indexing_technique
|
# if update indexing_technique
|
||||||
if knowledge_base_setting.index_method.indexing_technique == "economy":
|
if knowledge_configuration.index_method.indexing_technique == "economy":
|
||||||
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
||||||
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
|
elif knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||||
action = "add"
|
action = "add"
|
||||||
# get embedding model setting
|
# get embedding model setting
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
embedding_model = model_manager.get_model_instance(
|
embedding_model = model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
|
||||||
)
|
)
|
||||||
dataset.embedding_model = embedding_model.model
|
dataset.embedding_model = embedding_model.model
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
@ -607,9 +608,9 @@ class DatasetService:
|
|||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ValueError(ex.description)
|
raise ValueError(ex.description)
|
||||||
elif dataset.indexing_technique == "economy":
|
elif dataset.indexing_technique == "economy":
|
||||||
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
|
if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
|
||||||
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
|
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
|
||||||
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
|
||||||
session.add(dataset)
|
session.add(dataset)
|
||||||
session.commit()
|
session.commit()
|
||||||
if action:
|
if action:
|
||||||
|
|||||||
@ -47,7 +47,7 @@ from models.workflow import (
|
|||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||||
|
|
||||||
@ -262,7 +262,6 @@ class RagPipelineService:
|
|||||||
session: Session,
|
session: Session,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
account: Account,
|
account: Account,
|
||||||
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
|
||||||
) -> Workflow:
|
) -> Workflow:
|
||||||
draft_workflow_stmt = select(Workflow).where(
|
draft_workflow_stmt = select(Workflow).where(
|
||||||
Workflow.tenant_id == pipeline.tenant_id,
|
Workflow.tenant_id == pipeline.tenant_id,
|
||||||
@ -291,16 +290,23 @@ class RagPipelineService:
|
|||||||
# commit db session changes
|
# commit db session changes
|
||||||
session.add(workflow)
|
session.add(workflow)
|
||||||
|
|
||||||
# update dataset
|
graph = workflow.graph_dict
|
||||||
dataset = pipeline.dataset
|
nodes = graph.get("nodes", [])
|
||||||
if not dataset:
|
for node in nodes:
|
||||||
raise ValueError("Dataset not found")
|
if node.get("data", {}).get("type") == "knowledge_index":
|
||||||
DatasetService.update_rag_pipeline_dataset_settings(
|
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
|
||||||
session=session,
|
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
|
||||||
dataset=dataset,
|
|
||||||
knowledge_base_setting=knowledge_base_setting,
|
# update dataset
|
||||||
has_published=pipeline.is_published
|
dataset = pipeline.dataset
|
||||||
)
|
if not dataset:
|
||||||
|
raise ValueError("Dataset not found")
|
||||||
|
DatasetService.update_rag_pipeline_dataset_settings(
|
||||||
|
session=session,
|
||||||
|
dataset=dataset,
|
||||||
|
knowledge_configuration=knowledge_configuration,
|
||||||
|
has_published=pipeline.is_published
|
||||||
|
)
|
||||||
# return new workflow
|
# return new workflow
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user