diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index 1a4e9240b6..f502157eda 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -15,6 +15,7 @@ from libs.login import login_required
 from models.dataset import DatasetPermissionEnum
 from services.dataset_service import DatasetPermissionService, DatasetService
 from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
+from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 
 def _validate_name(name):
@@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource):
             raise Forbidden()
         rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
         try:
-            import_info = DatasetService.create_rag_pipeline_dataset(
+            import_info = RagPipelineDslService.create_rag_pipeline_dataset(
                 tenant_id=current_user.current_tenant_id,
                 rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
             )
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index bbeaa33341..fc6eab529a 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -40,6 +40,7 @@ from libs.login import current_user, login_required
 from models.account import Account
 from models.dataset import Pipeline
 from models.model import EndUser
+from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource):
         parser.add_argument("datasource_info_list", type=list, required=True, location="json")
         parser.add_argument("start_node_id", type=str, required=True, location="json")
         parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
+        parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
         args = parser.parse_args()
 
+        streaming = args["response_mode"] == "streaming"
+
         try:
             response = PipelineGenerateService.generate(
                 pipeline=pipeline,
                 user=current_user,
                 args=args,
                 invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
-                streaming=True,
+                streaming=streaming,
             )
 
             return helper.compact_generate_response(response)
@@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource):
             raise Forbidden()
 
         parser = reqparse.RequestParser()
-        parser.add_argument("marked_name", type=str, required=False, default="", location="json")
-        parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
+        parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
         args = parser.parse_args()
 
-        # Validate name and comment length
-        if args.marked_name and len(args.marked_name) > 20:
-            raise ValueError("Marked name cannot exceed 20 characters")
-        if args.marked_comment and len(args.marked_comment) > 100:
-            raise ValueError("Marked comment cannot exceed 100 characters")
+        if not args.get("knowledge_base_setting"):
+            raise ValueError("Missing knowledge base setting.")
+        knowledge_base_setting_data = args.get("knowledge_base_setting")
+        if not knowledge_base_setting_data:
+            raise ValueError("Missing knowledge base setting.")
+
+        knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
 
         rag_pipeline_service = RagPipelineService()
         with Session(db.engine) as session:
             pipeline = session.merge(pipeline)
@@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource):
                 session=session,
                 pipeline=pipeline,
                 account=current_user,
-                marked_name=args.marked_name or "",
-                marked_comment=args.marked_comment or "",
+                knowledge_base_setting=knowledge_base_setting,
             )
             pipeline.is_published = True
             pipeline.workflow_id = workflow.id
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 9c25f8f4e6..e4c96775c8 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -28,10 +28,13 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
+from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
 
@@ -51,7 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[True],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
-    ) -> Generator[Mapping | str, None, None]: ...
+    ) -> Generator[Mapping | str, None, None] | None: ...
@overload def generate( @@ -92,7 +95,7 @@ class PipelineGenerator(BaseAppGenerator): streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # convert to app config pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline=pipeline, @@ -119,14 +122,14 @@ class PipelineGenerator(BaseAppGenerator): document = self._build_document( tenant_id=pipeline.tenant_id, dataset_id=dataset.id, - built_in_field_enabled=pipeline.dataset.built_in_field_enabled, + built_in_field_enabled=dataset.built_in_field_enabled, datasource_type=datasource_type, datasource_info=datasource_info, created_from="rag-pipeline", position=position, account=user, batch=batch, - document_form=pipeline.dataset.chunk_structure, + document_form=dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -138,7 +141,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline_config=pipeline_config, datasource_type=datasource_type, datasource_info=datasource_info, - dataset_id=pipeline.dataset.id, + dataset_id=dataset.id, start_node_id=start_node_id, batch=batch, document_id=document_id, @@ -159,15 +162,24 @@ class PipelineGenerator(BaseAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, ) if invoke_from == InvokeFrom.DEBUGGER: return self._generate( @@ -176,6 +188,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -187,6 +200,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -200,6 +214,7 @@ class PipelineGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: RagPipelineGenerateEntity, invoke_from: InvokeFrom, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, @@ -207,11 +222,12 @@ class 
PipelineGenerator(BaseAppGenerator): """ Generate App response. - :param app_model: App + :param pipeline: Pipeline :param workflow: Workflow :param user: account or end user :param application_generate_entity: application generate entity :param invoke_from: invoke from source + :param workflow_execution_repository: repository for workflow execution :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id @@ -244,6 +260,7 @@ class PipelineGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, ) @@ -276,16 +293,20 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) + + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") # init application generate entity - use RagPipelineGenerateEntity instead application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=app_config, - pipeline_config=app_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, datasource_type=args.get("datasource_type", ""), datasource_info=args.get("datasource_info", {}), - dataset_id=pipeline.dataset_id, + dataset_id=dataset.id, batch=args.get("batch", ""), document_id=args.get("document_id"), inputs={}, @@ -299,10 +320,16 @@ class PipelineGenerator(BaseAppGenerator): contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -316,6 +343,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -345,20 +373,30 @@ class PipelineGenerator(BaseAppGenerator): if args.get("inputs") is None: raise ValueError("inputs is required") + dataset = pipeline.dataset + if not dataset: + raise ValueError("Pipeline dataset is required") + # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow) + pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( + application_generate_entity = RagPipelineGenerateEntity( task_id=str(uuid.uuid4()), - app_config=app_config, + app_config=pipeline_config, + pipeline_config=pipeline_config, + datasource_type=args.get("datasource_type", ""), + 
datasource_info=args.get("datasource_info", {}), + batch=args.get("batch", ""), + document_id=args.get("document_id"), + dataset_id=dataset.id, inputs={}, files=[], user_id=user.id, stream=streaming, invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, - single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), + single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), workflow_run_id=str(uuid.uuid4()), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -368,6 +406,13 @@ class PipelineGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, + ) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=user, @@ -381,6 +426,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -438,6 +484,7 @@ class PipelineGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -459,6 +506,7 @@ class PipelineGenerator(BaseAppGenerator): user=user, stream=stream, workflow_node_execution_repository=workflow_node_execution_repository, + workflow_execution_repository=workflow_execution_repository, ) try: @@ -481,7 +529,7 @@ class PipelineGenerator(BaseAppGenerator): datasource_info: Mapping[str, Any], created_from: str, position: int, - account: Account, + account: Union[Account, EndUser], batch: str, document_form: str, ): diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 23dbfef70d..8d90e7ee3e 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, ) +from core.variables.variables import RAGPipelineVariable from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -106,12 +107,19 @@ class PipelineRunner(WorkflowBasedAppRunner): SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, } + rag_pipeline_variables = {} + if workflow.rag_pipeline_variables: + for v in workflow.rag_pipeline_variables: + rag_pipeline_variable = RAGPipelineVariable(**v) + if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs: + rag_pipeline_variables[rag_pipeline_variable.variable] = 
inputs[rag_pipeline_variable.variable] variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, conversation_variables=[], + rag_pipeline_variables=rag_pipeline_variables, ) # init graph diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 045ca64872..bae39dc8c7 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): - entity: DatasourceProviderEntityWithPlugin | None + entity: DatasourceProviderEntityWithPlugin tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: self.entity = entity self.tenant_id = tenant_id diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 51ff1fc6c1..9ddc25a637 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -14,9 +14,9 @@ class DatasourceRuntime(BaseModel): """ tenant_id: str - tool_id: Optional[str] = None + datasource_id: Optional[str] = None invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None + datasource_invoke_from: Optional[DatasourceInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 11168b4c26..8c0f20ce2d 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -11,7 +11,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon def __init__( self, - entity: DatasourceProviderEntityWithPlugin | None, + entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str, diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index b5212eb719..7847218bb9 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -30,22 +30,16 @@ class PluginDatasourceManager(BasePluginClient): return json_response - # response = self._request_with_plugin_daemon_response( - # "GET", - # f"plugin/{tenant_id}/management/datasources", - # list[PluginDatasourceProviderEntity], - # params={"page": 1, "page_size": 256}, - # transformer=transformer, - # ) + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasources", + list[PluginDatasourceProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) - # for provider in response: - # provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - - # # override the provider name for each tool to plugin_id/provider_name - # for datasource in provider.declaration.datasources: - # datasource.identity.provider = provider.declaration.identity.name - - return 
[PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())] + return [local_file_datasource_provider] + response def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 50511de16f..72e4923b58 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,7 +13,8 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Document as DatasetDocument class BaseIndexProcessor(ABC): @@ -37,6 +38,10 @@ class BaseIndexProcessor(ABC): @abstractmethod def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): raise NotImplementedError + + @abstractmethod + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + raise NotImplementedError @abstractmethod def retrieve( diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5eab77d4f8..559bc5d59b 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -131,7 +131,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): paragraph = GeneralStructureChunk(**chunks) documents = [] - for content in paragraph.general_chunk: + for content in paragraph.general_chunks: metadata = { "dataset_id": dataset.id, "document_id": document.id, @@ -151,3 +151,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor): elif dataset.indexing_technique == "economy": keyword = Keyword(dataset) keyword.add_texts(documents) + + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + paragraph = GeneralStructureChunk(**chunks) + preview = [] + for content in paragraph.general_chunks: + preview.append({"content": content}) + return { + "preview": preview, + "total_segments": len(paragraph.general_chunks) + } \ No newline at end of file diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 6300d05707..7a3f8f1c63 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -234,3 +234,19 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + parent_childs = ParentChildStructureChunk(**chunks) + preview = [] + for parent_child in parent_childs.parent_child_chunks: + preview.append( + { + "content": parent_child.parent_content, + "child_chunks": parent_child.child_contents + + } + ) + return { + "preview": preview, + "total_segments": len(parent_childs.parent_child_chunks) + } \ No newline at end of file diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py 
b/api/core/rag/index_processor/processor/qa_index_processor.py index 0055625e13..b415596254 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,7 +4,7 @@ import logging import re import threading import uuid -from typing import Optional +from typing import Any, Mapping, Optional import pandas as pd from flask import Flask, current_app @@ -20,7 +20,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -160,6 +160,12 @@ class QAIndexProcessor(BaseIndexProcessor): doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs + + def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + pass + + def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + return {"preview": chunks} def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 52795bbadf..9f0054a165 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -40,7 +40,7 @@ class GeneralStructureChunk(BaseModel): General Structure Chunk. """ - general_chunk: list[str] + general_chunks: list[str] class ParentChildChunk(BaseModel): diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..c0952383a9 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import cast from uuid import uuid4 -from pydantic import Field +from pydantic import BaseModel, Field from core.helper import encrypter @@ -93,3 +93,20 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + +class RAGPipelineVariable(BaseModel): + belong_to_node_id: str = Field(description="belong to which node id, shared means public") + type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") + label: str = Field(description="label") + description: str | None = Field(description="description", default="") + variable: str = Field(description="variable key", default="") + max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0) + default_value: str | None = Field(description="default value", default="") + placeholder: str | None = Field(description="placeholder", default="") + unit: str | None = Field(description="unit, applicable to Number", default="") + tooltips: str | None = Field(description="helpful text", default="") + allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list) + allowed_file_extensions: list[str] | None = Field(description="e.g. 
['.jpg', '.mp3']", default_factory=list) + allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list) + required: bool = Field(description="optional, default false", default=False) + options: list[str] | None = Field(default_factory=list) diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py index 59edcee456..7664be0983 100644 --- a/api/core/workflow/constants.py +++ b/api/core/workflow/constants.py @@ -1,4 +1,4 @@ SYSTEM_VARIABLE_NODE_ID = "sys" ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" -PIPELINE_VARIABLE_NODE_ID = "pipeline" +RAG_PIPELINE_VARIABLE_NODE_ID = "rag" diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index af26864c01..319833145e 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -10,7 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable from core.variables.segments import FileSegment, NoneSegment from factories import variable_factory -from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from ..constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from ..enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, File] @@ -42,6 +47,10 @@ class VariablePool(BaseModel): description="Conversation variables.", default_factory=list, ) + rag_pipeline_variables: Mapping[str, Any] = Field( + description="RAG pipeline variables.", + default_factory=dict, + ) def __init__( self, @@ -50,18 +59,21 @@ class VariablePool(BaseModel): user_inputs: Mapping[str, Any] | None = None, environment_variables: Sequence[Variable] | None = None, conversation_variables: Sequence[Variable] | None = None, + rag_pipeline_variables: Mapping[str, Any] | None = None, **kwargs, ): environment_variables = environment_variables or [] conversation_variables = conversation_variables or [] user_inputs = user_inputs or {} system_variables = system_variables or {} + rag_pipeline_variables = rag_pipeline_variables or {} super().__init__( system_variables=system_variables, user_inputs=user_inputs, environment_variables=environment_variables, conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables, **kwargs, ) @@ -73,6 +85,9 @@ class VariablePool(BaseModel): # Add conversation variables to the variable pool for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + # Add rag pipeline variables to the variable pool + for var, value in self.rag_pipeline_variables.items(): + self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value) def add(self, selector: Sequence[str], value: Any, /) -> None: """ diff --git a/api/core/workflow/entities/workflow_execution_entities.py b/api/core/workflow/entities/workflow_execution_entities.py index 200d4697b5..28fae53ced 100644 --- a/api/core/workflow/entities/workflow_execution_entities.py +++ b/api/core/workflow/entities/workflow_execution_entities.py @@ -20,6 +20,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" + RAG_PIPELINE = "rag-pipeline" class WorkflowExecutionStatus(StrEnum): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 36273d8ec1..c17f1eeb2b 100644 --- 
a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -173,7 +173,7 @@ class GraphEngine: ) return elif isinstance(item, NodeRunSucceededEvent): - if item.node_type == NodeType.END: + if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX): self.graph_runtime_state.outputs = ( dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result @@ -319,7 +319,7 @@ class GraphEngine: # It may not be necessary, but it is necessary. :) if ( self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() - == NodeType.END.value + in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value] ): break diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2f15fba3af..8f841f9564 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.file import File from core.plugin.impl.exc import PluginDaemonClientSideError -from core.variables.segments import ArrayAnySegment +from core.variables.segments import ArrayAnySegment, FileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.model import UploadFile from models.workflow import WorkflowNodeExecutionStatus from .entities import DatasourceNodeData @@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): provider_id=node_data.provider_id, datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), + datasource_type=DatasourceProviderType.value_of(datasource_type), ) except DatasourceNodeError as e: return NodeRunResult( @@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): error=f"Failed to get datasource runtime: {str(e)}", error_type=type(e).__name__, ) - + # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): "datasource_type": datasource_type, }, ) - case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE: + case DatasourceProviderType.WEBSITE_CRAWL: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, @@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): }, ) case DatasourceProviderType.LOCAL_FILE: + upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first() + if not upload_file: + raise ValueError("Invalid upload file Info") + + file_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." 
+ upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=datasource_info.get("type", ""), + transfer_method=datasource_info.get("transfer_method", ""), + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)]) + for key, value in datasource_info.items(): + # construct new key list + new_key_list = ["file", key] + self._append_variables_recursively( + variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={ - "file": datasource_info, - "datasource_type": datasource_runtime.datasource_provider_type, + "file_info": file_info, + "datasource_type": datasource_type, }, ) case _: raise DatasourceNodeError( - f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" + f"Unsupported datasource provider: {datasource_type}" ) except PluginDaemonClientSideError as e: return NodeRunResult( @@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] + + + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + variable_pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value + ) @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 1f414ad0e2..dee3c1d2fb 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceInput(BaseModel): # TODO: check this type - value: Optional[Union[Any, list[str]]] = None + value: Union[Any, list[str]] type: Optional[Literal["mixed", "variable", "constant"]] = None @field_validator("type", mode="before") diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 25a4112998..fef434e3ec 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeIndexNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool + dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id: + raise 
KnowledgeIndexNodeError("Dataset ID is required.") + dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.") + # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) - is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER + if not variable: + raise KnowledgeIndexNodeError("Index chunk variable is required.") + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + if invoke_from: + is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value + else: + is_preview = False chunks = variable.value variables = {"chunks": chunks} if not chunks: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." ) + outputs = self._get_preview_output(dataset.chunk_structure, chunks) + # retrieve knowledge try: if is_preview: @@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, - outputs={"result": "success"}, + outputs=outputs, ) - results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) - outputs = {"result": results} + results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks, + variable_pool=variable_pool) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results ) except KnowledgeIndexNodeError as e: @@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): ) def _invoke_knowledge_index( - self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool + self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], + variable_pool: VariablePool ) -> Any: - dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) - if not dataset_id: - raise KnowledgeIndexNodeError("Dataset ID is required.") document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if not document_id: raise KnowledgeIndexNodeError("Document ID is required.") batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) if not batch: raise KnowledgeIndexNodeError("Batch is required.") - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset: - raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") - - document = db.session.query(Document).filter_by(id=document_id).first() + document = db.session.query(Document).filter_by(id=document_id.value).first() if not document: - raise KnowledgeIndexNodeError(f"Document {document_id} not found.") + raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) @@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): # update document status document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.add(document) db.session.commit() return { "dataset_id": dataset.id, "dataset_name": dataset.name, - "batch": batch, + "batch": batch.value, "document_id": document.id, "document_name": document.name, - "created_at": 
document.created_at, + "created_at": document.created_at.timestamp(), "display_status": document.indexing_status, } + + def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + return index_processor.format_preview(chunks) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0733192c4f..c138266b14 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -41,10 +41,9 @@ conversation_variable_fields = { } pipeline_variable_fields = { - "id": fields.String, "label": fields.String, "variable": fields.String, - "type": fields.String(attribute="type.value"), + "type": fields.String, "belong_to_node_id": fields.String, "max_length": fields.Integer, "required": fields.Boolean, diff --git a/api/models/enums.py b/api/models/enums.py index 4434c3fec8..0afa204b1f 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -14,6 +14,8 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" + RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging" class DraftVariableType(StrEnum): diff --git a/api/models/workflow.py b/api/models/workflow.py index b37b0febe8..f0aba3572a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -152,6 +152,7 @@ class Workflow(Base): created_by: str, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", ) -> "Workflow": @@ -166,6 +167,7 @@ class Workflow(Base): workflow.created_by = created_by workflow.environment_variables = environment_variables or [] workflow.conversation_variables = conversation_variables or [] + workflow.rag_pipeline_variables = rag_pipeline_variables or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment workflow.created_at = datetime.now(UTC).replace(tzinfo=None) @@ -340,7 +342,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], - "rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables], + "rag_pipeline_variables": self.rag_pipeline_variables, } return result @@ -553,6 +555,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" + RAG_PIPELINE_RUN = "rag-pipeline-run" class WorkflowNodeExecutionStatus(StrEnum): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f1280375e0..6d3891799c 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) -from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeBaseUpdateConfiguration, + RagPipelineDatasetCreateEntity, +) from services.errors.account import InvalidActionError, NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError @@ -59,11 +62,11 @@ from services.errors.document import 
DocumentIndexingError
 from services.errors.file import FileNotExistsError
 from services.external_knowledge_service import ExternalDatasetService
 from services.feature_service import FeatureModel, FeatureService
-from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
 from services.tag_service import TagService
 from services.vector_service import VectorService
 from tasks.batch_clean_document_task import batch_clean_document_task
 from tasks.clean_notion_document_task import clean_notion_document_task
+from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@@ -278,47 +281,6 @@ class DatasetService:
             db.session.commit()
         return dataset
 
-    @staticmethod
-    def create_rag_pipeline_dataset(
-        tenant_id: str,
-        rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
-    ):
-        # check if dataset name already exists
-        if (
-            db.session.query(Dataset)
-            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
-            .first()
-        ):
-            raise DatasetNameDuplicateError(
-                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
-            )
-
-        dataset = Dataset(
-            name=rag_pipeline_dataset_create_entity.name,
-            description=rag_pipeline_dataset_create_entity.description,
-            permission=rag_pipeline_dataset_create_entity.permission,
-            provider="vendor",
-            runtime_mode="rag-pipeline",
-            icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
-        )
-        with Session(db.engine) as session:
-            rag_pipeline_dsl_service = RagPipelineDslService(session)
-            account = cast(Account, current_user)
-            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=account,
-                import_mode=ImportMode.YAML_CONTENT.value,
-                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=dataset,
-            )
-        return {
-            "id": rag_pipeline_import_info.id,
-            "dataset_id": dataset.id,
-            "pipeline_id": rag_pipeline_import_info.pipeline_id,
-            "status": rag_pipeline_import_info.status,
-            "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
-            "current_dsl_version": rag_pipeline_import_info.current_dsl_version,
-            "error": rag_pipeline_import_info.error,
-        }
 
     @staticmethod
     def get_dataset(dataset_id) -> Optional[Dataset]:
@@ -529,6 +491,130 @@ class DatasetService:
         if action:
             deal_dataset_vector_index_task.delay(dataset_id, action)
         return dataset
+
+    @staticmethod
+    def update_rag_pipeline_dataset_settings(session: Session,
+                                             dataset: Dataset,
+                                             knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
+                                             has_published: bool = False):
+        if not has_published:
+            dataset.chunk_structure = knowledge_base_setting.chunk_structure
+            index_method = knowledge_base_setting.index_method
+            dataset.indexing_technique = index_method.indexing_technique
+            if index_method.indexing_technique == "high_quality":
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=index_method.embedding_setting.embedding_provider_name,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=index_method.embedding_setting.embedding_model_name,
+                )
+                dataset.embedding_model = embedding_model.model
+                dataset.embedding_model_provider = embedding_model.provider
+                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                    embedding_model.provider, embedding_model.model
+                )
+                dataset.collection_binding_id = dataset_collection_binding.id
+            elif index_method.indexing_technique == "economy":
+                dataset.keyword_number = index_method.economy_setting.keyword_number
+            else:
+                raise ValueError("Invalid index method")
+            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+            session.add(dataset)
+        else:
+            if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
+                raise ValueError("Chunk structure is not allowed to be updated.")
+            action = None
+            if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
+                # if update indexing_technique
+                if knowledge_base_setting.index_method.indexing_technique == "economy":
+                    raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
+                elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
+                    action = "add"
+                    # get embedding model setting
+                    try:
+                        model_manager = ModelManager()
+                        embedding_model = model_manager.get_model_instance(
+                            tenant_id=current_user.current_tenant_id,
+                            provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
+                            model_type=ModelType.TEXT_EMBEDDING,
+                            model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
+                        )
+                        dataset.embedding_model = embedding_model.model
+                        dataset.embedding_model_provider = embedding_model.provider
+                        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                            embedding_model.provider, embedding_model.model
+                        )
+                        dataset.collection_binding_id = dataset_collection_binding.id
+                    except LLMBadRequestError:
+                        raise ValueError(
+                            "No Embedding Model available. Please configure a valid provider "
+                            "in the Settings -> Model Provider."
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + else: + # add default plugin id to both setting sets, to make sure the plugin model provider is consistent + # Skip embedding model checks if not provided in the update request + if dataset.indexing_technique == "high_quality": + skip_embedding_update = False + try: + # Handle existing model provider + plugin_model_provider = dataset.embedding_model_provider + plugin_model_provider_str = None + if plugin_model_provider: + plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) + + # Handle new model provider from request + new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name + new_plugin_model_provider_str = None + if new_plugin_model_provider: + new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) + + # Only update embedding model if both values are provided and different from current + if ( + plugin_model_provider_str != new_plugin_model_provider_str + or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model + ): + action = "update" + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, + model_type=ModelType.TEXT_EMBEDDING, + model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, skip updating it + # and keep the existing settings if available + # Skip the rest of the embedding model update + skip_embedding_update = True + if not skip_embedding_update: + dataset.embedding_model = embedding_model.model + dataset.embedding_model_provider = embedding_model.provider + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + ) + dataset.collection_binding_id = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + elif dataset.indexing_technique == "economy": + if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number: + dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number + dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() + session.add(dataset) + session.commit() + if action: + deal_dataset_index_update_task.delay(dataset.id, action) + @staticmethod def delete_dataset(dataset_id, user): diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index fbb9b25a75..54abc64547 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,29 +4,12 @@ from typing import Optional from flask_login import current_user from constants import HIDDEN_VALUE -from core import datasource -from core.datasource.__base import datasource_provider -from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.helper import encrypter -from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.entities.provider_entities import FormType from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.plugin.impl.datasource import PluginDatasourceManager -from core.provider_manager import ProviderManager +from extensions.ext_database import db from models.oauth import DatasourceProvider -from models.provider import ProviderType -from services.entities.model_provider_entities import ( - CustomConfigurationResponse, - CustomConfigurationStatus, - DefaultModelResponse, - ModelWithProviderEntityResponse, - ProviderResponse, - ProviderWithModelsResponse, - SimpleProviderEntityResponse, - SystemConfigurationResponse, -) -from extensions.database import db logger = logging.getLogger(__name__) @@ -115,16 +98,26 @@ class DatasourceProviderService: :param tenant_id: workspace id :param provider: provider name - :param datasource_name: datasource name :param plugin_id: plugin id :return: """ # Get all provider configurations of the current workspace - datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id).first() + if not datasource_provider: + return None + encrypted_credentials = datasource_provider.encrypted_credentials + # Get provider credential secret variables + credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider) + # Obfuscate provider credentials + copy_credentials = encrypted_credentials.copy() + for key, value in copy_credentials.items(): + if key in credential_secret_variables: + copy_credentials[key] = encrypter.obfuscated_token(value) + return copy_credentials def remove_datasource_credentials(self, @@ -136,11 +129,9 @@ class DatasourceProviderService: :param tenant_id: workspace id :param provider: provider name - :param datasource_name: datasource name :param plugin_id: plugin id :return: """ - # Get all provider configurations of the current workspace datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id).first() 
diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 800bd24021..17416d51fd 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel): chunk_structure: str index_method: IndexMethod retrieval_setting: RetrievalSetting + + +class KnowledgeBaseUpdateConfiguration(BaseModel): + """ + Knowledge Base Update Configuration. + """ + index_method: IndexMethod + chunk_structure: str + retrieval_setting: RetrievalSetting \ No newline at end of file diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 14594be351..911086066a 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -69,9 +69,9 @@ class PipelineGenerateService: @classmethod def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) - return WorkflowAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_loop_generate( - app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_loop_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 11071d82e7..9ea3cc678b 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: - pipeline_model: Pipeline = pipeline_built_in_template.pipeline + pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline + if not pipeline_model: + continue recommended_pipeline_result = { "id": pipeline_built_in_template.id, @@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): "privacy_policy": pipeline_built_in_template.privacy_policy, "position": pipeline_built_in_template.position, } - dataset: Dataset = pipeline_model.dataset + dataset: Dataset | None = pipeline_model.dataset if dataset: recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure recommended_pipelines_results.append(recommended_pipeline_result) @@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if not pipeline_template: return None - # get app detail + # get pipeline detail pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first() if not pipeline or not pipeline.is_public: return None + dataset: Dataset | None = pipeline.dataset + if not dataset: + return None + return { "id": pipeline.id, "name": pipeline.name, - "icon": pipeline.icon, - "mode": pipeline.mode, + "icon": pipeline_template.icon, + "chunk_structure": dataset.chunk_structure, "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), } diff --git 
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 5a69e69a16..9e7a1d7fe2 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -46,7 +46,8 @@
 from models.workflow import (
     WorkflowRun,
     WorkflowType,
 )
-from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
+from services.dataset_service import DatasetService
+from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@@ -261,8 +262,7 @@ class RagPipelineService:
         session: Session,
         pipeline: Pipeline,
         account: Account,
-        marked_name: str = "",
-        marked_comment: str = "",
+        knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
     ) -> Workflow:
         draft_workflow_stmt = select(Workflow).where(
             Workflow.tenant_id == pipeline.tenant_id,
@@ -282,18 +282,25 @@ class RagPipelineService:
             graph=draft_workflow.graph,
             features=draft_workflow.features,
             created_by=account.id,
-            environment_variables=draft_workflow.environment_variables,
+            environment_variables=draft_workflow.environment_variables,
             conversation_variables=draft_workflow.conversation_variables,
-            marked_name=marked_name,
-            marked_comment=marked_comment,
+            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
+            marked_name="",
+            marked_comment="",
         )
-        # commit db session changes
         session.add(workflow)
-        # trigger app workflow events TODO
-        # app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
-
+        # update dataset
+        dataset = pipeline.dataset
+        if not dataset:
+            raise ValueError("Dataset not found")
+        DatasetService.update_rag_pipeline_dataset_settings(
+            session=session,
+            dataset=dataset,
+            knowledge_base_setting=knowledge_base_setting,
+            has_published=pipeline.is_published
+        )
         # return new workflow
         return workflow
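publish_workflow now freezes the current draft into a new Workflow row (with blank marked_name/marked_comment) and, in the same session, pushes the knowledge-base settings down to the dataset. The stand-in dataclasses below sketch only the copy-and-freeze half of that, under the assumption that a published record is essentially an immutable snapshot of the draft's graph and features; the real model carries many more columns:

    from dataclasses import dataclass

    # Stand-ins for the Workflow model; illustrative only.
    @dataclass
    class DraftWorkflow:
        graph: dict
        features: dict

    @dataclass
    class PublishedWorkflow:
        graph: dict
        features: dict
        marked_name: str = ""
        marked_comment: str = ""

    def snapshot(draft: DraftWorkflow) -> PublishedWorkflow:
        # Copy the draft's state so later draft edits cannot mutate the published version.
        return PublishedWorkflow(graph=dict(draft.graph), features=dict(draft.features))

    draft = DraftWorkflow(graph={"nodes": [], "edges": []}, features={})
    print(snapshot(draft))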
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index acd364f6cd..c6751825cc 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -4,13 +4,14 @@
 import logging
 import uuid
 from collections.abc import Mapping
 from enum import StrEnum
-from typing import Optional
+from typing import Optional, cast
 from urllib.parse import urlparse
 from uuid import uuid4
 import yaml  # type: ignore
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
+from flask_login import current_user
 from packaging import version
 from pydantic import BaseModel, Field
 from sqlalchemy import select
@@ -31,7 +32,10 @@
 from factories import variable_factory
 from models import Account
 from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
 from models.workflow import Workflow
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
+from services.entities.knowledge_entities.rag_pipeline_entities import (
+    KnowledgeConfiguration,
+    RagPipelineDatasetCreateEntity,
+)
 from services.plugin.dependencies_analysis import DependenciesAnalysisService
 from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -540,9 +544,6 @@ class RagPipelineDslService:
             # Update existing pipeline
             pipeline.name = pipeline_data.get("name", pipeline.name)
             pipeline.description = pipeline_data.get("description", pipeline.description)
-            pipeline.icon_type = icon_type
-            pipeline.icon = icon
-            pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
             pipeline.updated_by = account.id
         else:
             if account.current_tenant_id is None:
@@ -554,12 +555,6 @@ class RagPipelineDslService:
             pipeline.tenant_id = account.current_tenant_id
             pipeline.name = pipeline_data.get("name", "")
             pipeline.description = pipeline_data.get("description", "")
-            pipeline.icon_type = icon_type
-            pipeline.icon = icon
-            pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
-            pipeline.enable_site = True
-            pipeline.enable_api = True
-            pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
             pipeline.created_by = account.id
             pipeline.updated_by = account.id
@@ -674,26 +669,6 @@ class RagPipelineDslService:
             )
         ]
-    @classmethod
-    def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
-        """
-        Append model config export data
-        :param export_data: export data
-        :param pipeline: Pipeline instance
-        """
-        app_model_config = pipeline.app_model_config
-        if not app_model_config:
-            raise ValueError("Missing app configuration, please check.")
-
-        export_data["model_config"] = app_model_config.to_dict()
-        dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
-        export_data["dependencies"] = [
-            jsonable_encoder(d.model_dump())
-            for d in DependenciesAnalysisService.generate_dependencies(
-                tenant_id=pipeline.tenant_id, dependencies=dependencies
-            )
-        ]
-
     @classmethod
     def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
         """
@@ -863,3 +838,46 @@ class RagPipelineDslService:
             return pt.decode()
         except Exception:
             return None
+
+
+    @staticmethod
+    def create_rag_pipeline_dataset(
+        tenant_id: str,
+        rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
+    ):
+        # check if dataset name already exists
+        if (
+            db.session.query(Dataset)
+            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
+            .first()
+        ):
+            raise ValueError(
+                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
+            )
+
+        dataset = Dataset(
+            name=rag_pipeline_dataset_create_entity.name,
+            description=rag_pipeline_dataset_create_entity.description,
+            permission=rag_pipeline_dataset_create_entity.permission,
+            provider="vendor",
+            runtime_mode="rag-pipeline",
+            icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
+        )
+        with Session(db.engine) as session:
+            rag_pipeline_dsl_service = RagPipelineDslService(session)
+            account = cast(Account, current_user)
+            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
+                account=account,
+                import_mode=ImportMode.YAML_CONTENT.value,
+                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+                dataset=dataset,
+            )
+            return {
+                "id": rag_pipeline_import_info.id,
+                "dataset_id": dataset.id,
+                "pipeline_id": rag_pipeline_import_info.pipeline_id,
+                "status": rag_pipeline_import_info.status,
+                "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
+                "current_dsl_version": rag_pipeline_import_info.current_dsl_version,
+                "error": rag_pipeline_import_info.error,
+            }
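create_rag_pipeline_dataset rejects duplicate dataset names, creates a vendor-provided dataset in rag-pipeline runtime mode, and then imports the supplied DSL before returning a summary dict. A small sketch of how a caller might consume that summary; the keys mirror the return value above, while the concrete values and the "completed" status string are illustrative assumptions:

    # Hypothetical import summary; shape follows the dict returned above, values are made up.
    import_info = {
        "id": "import-1",
        "dataset_id": "dataset-1",
        "pipeline_id": "pipeline-1",
        "status": "completed",
        "imported_dsl_version": "0.1.0",
        "current_dsl_version": "0.1.0",
        "error": "",
    }

    if import_info["status"] != "completed":
        raise RuntimeError(f"DSL import failed: {import_info['error'] or 'unknown error'}")
    if import_info["imported_dsl_version"] != import_info["current_dsl_version"]:
        print("Imported DSL targets an older version; review the pipeline before publishing.")
    print(f"Created dataset {import_info['dataset_id']} backed by pipeline {import_info['pipeline_id']}")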
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
new file mode 100644
index 0000000000..dc266aef65
--- /dev/null
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -0,0 +1,171 @@
+import logging
+import time
+
+import click
+from celery import shared_task  # type: ignore
+
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import ChildDocument, Document
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+
+@shared_task(queue="dataset")
+def deal_dataset_index_update_task(dataset_id: str, action: str):
+    """
+    Async deal dataset from index
+    :param dataset_id: dataset_id
+    :param action: action
+    Usage: deal_dataset_index_update_task.delay(dataset_id, action)
+    """
+    logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+
+        if not dataset:
+            raise Exception("Dataset not found")
+        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        if action == "upgrade":
+            dataset_documents = (
+                db.session.query(DatasetDocument)
+                .filter(
+                    DatasetDocument.dataset_id == dataset_id,
+                    DatasetDocument.indexing_status == "completed",
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                )
+                .all()
+            )
+
+            if dataset_documents:
+                dataset_documents_ids = [doc.id for doc in dataset_documents]
+                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                    {"indexing_status": "indexing"}, synchronize_session=False
+                )
+                db.session.commit()
+
+                for dataset_document in dataset_documents:
+                    try:
+                        # add from vector index
+                        segments = (
+                            db.session.query(DocumentSegment)
+                            .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                            .order_by(DocumentSegment.position.asc())
+                            .all()
+                        )
+                        if segments:
+                            documents = []
+                            for segment in segments:
+                                document = Document(
+                                    page_content=segment.content,
+                                    metadata={
+                                        "doc_id": segment.index_node_id,
+                                        "doc_hash": segment.index_node_hash,
+                                        "document_id": segment.document_id,
+                                        "dataset_id": segment.dataset_id,
+                                    },
+                                )
+
+                                documents.append(document)
+                            # save vector index
+                            # clean keywords
+                            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
+                            index_processor.load(dataset, documents, with_keywords=False)
+                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+                            {"indexing_status": "completed"}, synchronize_session=False
+                        )
+                        db.session.commit()
+                    except Exception as e:
+                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                        )
+                        db.session.commit()
+        elif action == "update":
+            dataset_documents = (
+                db.session.query(DatasetDocument)
+                .filter(
+                    DatasetDocument.dataset_id == dataset_id,
+                    DatasetDocument.indexing_status == "completed",
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                )
+                .all()
+            )
+            # add new index
+            if dataset_documents:
+                # update document status
+                dataset_documents_ids = [doc.id for doc in dataset_documents]
+                db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+                    {"indexing_status": "indexing"}, synchronize_session=False
+                )
+                db.session.commit()
+
+                # clean index
+                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+                for dataset_document in dataset_documents:
+                    # update from vector index
+                    try:
+                        segments = (
+                            db.session.query(DocumentSegment)
+                            .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                            .order_by(DocumentSegment.position.asc())
+                            .all()
+                        )
+                        if segments:
+                            documents = []
+                            for segment in segments:
+                                document = Document(
+                                    page_content=segment.content,
+                                    metadata={
+                                        "doc_id": segment.index_node_id,
+                                        "doc_hash": segment.index_node_hash,
+                                        "document_id": segment.document_id,
+                                        "dataset_id": segment.dataset_id,
+                                    },
+                                )
+                                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                                    child_chunks = segment.get_child_chunks()
+                                    if child_chunks:
+                                        child_documents = []
+                                        for child_chunk in child_chunks:
+                                            child_document = ChildDocument(
+                                                page_content=child_chunk.content,
+                                                metadata={
+                                                    "doc_id": child_chunk.index_node_id,
+                                                    "doc_hash": child_chunk.index_node_hash,
+                                                    "document_id": segment.document_id,
+                                                    "dataset_id": segment.dataset_id,
+                                                },
+                                            )
+                                            child_documents.append(child_document)
+                                        document.children = child_documents
+                                documents.append(document)
+                            # save vector index
+                            index_processor.load(dataset, documents, with_keywords=False)
+                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+                            {"indexing_status": "completed"}, synchronize_session=False
+                        )
+                        db.session.commit()
+                    except Exception as e:
+                        db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+                        )
+                        db.session.commit()
+        else:
+            # clean collection
+            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
+        )
+    except Exception:
+        logging.exception("Deal dataset vector index failed")
+    finally:
+        db.session.close()
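The task dispatches on the action string: "upgrade" rebuilds the vector index and clears keyword indexes for every completed document, "update" re-embeds those documents into a freshly cleaned collection (carrying child chunks for parent-child indexes), and any other value only cleans the collection. How the caller derives that action is not part of this diff, so the mapping below is purely an illustrative assumption:

    from typing import Optional

    # Illustrative only: one possible way a caller could pick an action before enqueueing
    # deal_dataset_index_update_task.delay(dataset.id, action); the real rules live in
    # DatasetService.update_rag_pipeline_dataset_settings and are only partially visible here.
    def choose_action(old_technique: str, new_technique: str, embedding_model_changed: bool) -> Optional[str]:
        if old_technique == "economy" and new_technique == "high_quality":
            return "upgrade"  # assumed: moving to vector indexing rebuilds everything
        if new_technique == "high_quality" and embedding_model_changed:
            return "update"  # assumed: re-embed into a clean collection
        return None  # nothing to reindex

    action = choose_action("economy", "high_quality", embedding_model_changed=False)
    if action:
        print(f"would enqueue deal_dataset_index_update_task with action={action!r}")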