mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
r2
This commit is contained in:
parent
5fc2bc58a9
commit
7f59ffe7af
@ -15,6 +15,7 @@ from libs.login import login_required
|
|||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||||
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name):
|
||||||
@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
||||||
try:
|
try:
|
||||||
import_info = DatasetService.create_rag_pipeline_dataset(
|
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -40,6 +40,7 @@ from libs.login import current_user, login_required
|
|||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||||
@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||||
|
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = PipelineGenerateService.generate(
|
response = PipelineGenerateService.generate(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
args=args,
|
args=args,
|
||||||
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
|
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
|
||||||
streaming=True,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
|
||||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
if not args.get("knowledge_base_setting"):
|
||||||
if args.marked_name and len(args.marked_name) > 20:
|
raise ValueError("Missing knowledge base setting.")
|
||||||
raise ValueError("Marked name cannot exceed 20 characters")
|
|
||||||
if args.marked_comment and len(args.marked_comment) > 100:
|
|
||||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
|
||||||
|
|
||||||
|
knowledge_base_setting_data = args.get("knowledge_base_setting")
|
||||||
|
if not knowledge_base_setting_data:
|
||||||
|
raise ValueError("Missing knowledge base setting.")
|
||||||
|
|
||||||
|
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
pipeline = session.merge(pipeline)
|
pipeline = session.merge(pipeline)
|
||||||
@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
session=session,
|
session=session,
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
account=current_user,
|
account=current_user,
|
||||||
marked_name=args.marked_name or "",
|
knowledge_base_setting=knowledge_base_setting,
|
||||||
marked_comment=args.marked_comment or "",
|
|
||||||
)
|
)
|
||||||
pipeline.is_published = True
|
pipeline.is_published = True
|
||||||
pipeline.workflow_id = workflow.id
|
pipeline.workflow_id = workflow.id
|
||||||
|
|||||||
@ -28,10 +28,13 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo
|
|||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||||
|
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.dataset import Document, Pipeline
|
from models.dataset import Document, Pipeline
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.dataset_service import DocumentService
|
from services.dataset_service import DocumentService
|
||||||
|
|
||||||
@ -51,7 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_thread_pool_id: Optional[str],
|
workflow_thread_pool_id: Optional[str],
|
||||||
) -> Generator[Mapping | str, None, None]: ...
|
) -> Generator[Mapping | str, None, None] | None: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def generate(
|
def generate(
|
||||||
@ -92,7 +95,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||||
# convert to app config
|
# convert to app config
|
||||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
@ -119,14 +122,14 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
document = self._build_document(
|
document = self._build_document(
|
||||||
tenant_id=pipeline.tenant_id,
|
tenant_id=pipeline.tenant_id,
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
|
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||||
datasource_type=datasource_type,
|
datasource_type=datasource_type,
|
||||||
datasource_info=datasource_info,
|
datasource_info=datasource_info,
|
||||||
created_from="rag-pipeline",
|
created_from="rag-pipeline",
|
||||||
position=position,
|
position=position,
|
||||||
account=user,
|
account=user,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
document_form=pipeline.dataset.chunk_structure,
|
document_form=dataset.chunk_structure,
|
||||||
)
|
)
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -138,7 +141,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
pipeline_config=pipeline_config,
|
pipeline_config=pipeline_config,
|
||||||
datasource_type=datasource_type,
|
datasource_type=datasource_type,
|
||||||
datasource_info=datasource_info,
|
datasource_info=datasource_info,
|
||||||
dataset_id=pipeline.dataset.id,
|
dataset_id=dataset.id,
|
||||||
start_node_id=start_node_id,
|
start_node_id=start_node_id,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
@ -159,15 +162,24 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
|
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
|
||||||
|
else:
|
||||||
|
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=workflow_triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
user=user,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||||
)
|
)
|
||||||
if invoke_from == InvokeFrom.DEBUGGER:
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -176,6 +188,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user=user,
|
user=user,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
@ -187,6 +200,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user=user,
|
user=user,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
@ -200,6 +214,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
application_generate_entity: RagPipelineGenerateEntity,
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
@ -207,11 +222,12 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
:param app_model: App
|
:param pipeline: Pipeline
|
||||||
:param workflow: Workflow
|
:param workflow: Workflow
|
||||||
:param user: account or end user
|
:param user: account or end user
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
:param invoke_from: invoke from source
|
:param invoke_from: invoke from source
|
||||||
|
:param workflow_execution_repository: repository for workflow execution
|
||||||
:param workflow_node_execution_repository: repository for workflow node execution
|
:param workflow_node_execution_repository: repository for workflow node execution
|
||||||
:param streaming: is stream
|
:param streaming: is stream
|
||||||
:param workflow_thread_pool_id: workflow thread pool id
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
@ -244,6 +260,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
user=user,
|
user=user,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
)
|
)
|
||||||
@ -276,16 +293,20 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
raise ValueError("inputs is required")
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||||
|
|
||||||
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Pipeline dataset is required")
|
||||||
|
|
||||||
# init application generate entity - use RagPipelineGenerateEntity instead
|
# init application generate entity - use RagPipelineGenerateEntity instead
|
||||||
application_generate_entity = RagPipelineGenerateEntity(
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=pipeline_config,
|
||||||
pipeline_config=app_config,
|
pipeline_config=pipeline_config,
|
||||||
datasource_type=args.get("datasource_type", ""),
|
datasource_type=args.get("datasource_type", ""),
|
||||||
datasource_info=args.get("datasource_info", {}),
|
datasource_info=args.get("datasource_info", {}),
|
||||||
dataset_id=pipeline.dataset_id,
|
dataset_id=dataset.id,
|
||||||
batch=args.get("batch", ""),
|
batch=args.get("batch", ""),
|
||||||
document_id=args.get("document_id"),
|
document_id=args.get("document_id"),
|
||||||
inputs={},
|
inputs={},
|
||||||
@ -299,10 +320,16 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||||
|
)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
user=user,
|
user=user,
|
||||||
@ -316,6 +343,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user=user,
|
user=user,
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
@ -345,20 +373,30 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
if args.get("inputs") is None:
|
if args.get("inputs") is None:
|
||||||
raise ValueError("inputs is required")
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Pipeline dataset is required")
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
|
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=pipeline_config,
|
||||||
|
pipeline_config=pipeline_config,
|
||||||
|
datasource_type=args.get("datasource_type", ""),
|
||||||
|
datasource_info=args.get("datasource_info", {}),
|
||||||
|
batch=args.get("batch", ""),
|
||||||
|
document_id=args.get("document_id"),
|
||||||
|
dataset_id=dataset.id,
|
||||||
inputs={},
|
inputs={},
|
||||||
files=[],
|
files=[],
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
extras={"auto_generate_conversation_name": False},
|
extras={"auto_generate_conversation_name": False},
|
||||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||||
workflow_run_id=str(uuid.uuid4()),
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
@ -368,6 +406,13 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||||
|
)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
user=user,
|
user=user,
|
||||||
@ -381,6 +426,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user=user,
|
user=user,
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
@ -438,6 +484,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
@ -459,6 +506,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
user=user,
|
user=user,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -481,7 +529,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
datasource_info: Mapping[str, Any],
|
datasource_info: Mapping[str, Any],
|
||||||
created_from: str,
|
created_from: str,
|
||||||
position: int,
|
position: int,
|
||||||
account: Account,
|
account: Union[Account, EndUser],
|
||||||
batch: str,
|
batch: str,
|
||||||
document_form: str,
|
document_form: str,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
RagPipelineGenerateEntity,
|
RagPipelineGenerateEntity,
|
||||||
)
|
)
|
||||||
|
from core.variables.variables import RAGPipelineVariable
|
||||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
@ -106,12 +107,19 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||||||
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
||||||
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
|
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
|
||||||
}
|
}
|
||||||
|
rag_pipeline_variables = {}
|
||||||
|
if workflow.rag_pipeline_variables:
|
||||||
|
for v in workflow.rag_pipeline_variables:
|
||||||
|
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||||
|
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
|
||||||
|
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
|
||||||
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
user_inputs=inputs,
|
user_inputs=inputs,
|
||||||
environment_variables=workflow.environment_variables,
|
environment_variables=workflow.environment_variables,
|
||||||
conversation_variables=[],
|
conversation_variables=[],
|
||||||
|
rag_pipeline_variables=rag_pipeline_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
|
|||||||
@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
|||||||
|
|
||||||
|
|
||||||
class DatasourcePluginProviderController(ABC):
|
class DatasourcePluginProviderController(ABC):
|
||||||
entity: DatasourceProviderEntityWithPlugin | None
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
|
|
||||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None:
|
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
|
|||||||
@ -14,9 +14,9 @@ class DatasourceRuntime(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
tool_id: Optional[str] = None
|
datasource_id: Optional[str] = None
|
||||||
invoke_from: Optional[InvokeFrom] = None
|
invoke_from: Optional[InvokeFrom] = None
|
||||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
|
||||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
entity: DatasourceProviderEntityWithPlugin | None,
|
entity: DatasourceProviderEntityWithPlugin,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
plugin_unique_identifier: str,
|
plugin_unique_identifier: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|||||||
@ -30,22 +30,16 @@ class PluginDatasourceManager(BasePluginClient):
|
|||||||
|
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
# response = self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response(
|
||||||
# "GET",
|
"GET",
|
||||||
# f"plugin/{tenant_id}/management/datasources",
|
f"plugin/{tenant_id}/management/datasources",
|
||||||
# list[PluginDatasourceProviderEntity],
|
list[PluginDatasourceProviderEntity],
|
||||||
# params={"page": 1, "page_size": 256},
|
params={"page": 1, "page_size": 256},
|
||||||
# transformer=transformer,
|
transformer=transformer,
|
||||||
# )
|
)
|
||||||
|
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||||
|
|
||||||
# for provider in response:
|
return [local_file_datasource_provider] + response
|
||||||
# provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
|
||||||
|
|
||||||
# # override the provider name for each tool to plugin_id/provider_name
|
|
||||||
# for datasource in provider.declaration.datasources:
|
|
||||||
# datasource.identity.provider = provider.declaration.identity.name
|
|
||||||
|
|
||||||
return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())]
|
|
||||||
|
|
||||||
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
|
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -13,7 +13,8 @@ from core.rag.splitter.fixed_text_splitter import (
|
|||||||
FixedRecursiveCharacterTextSplitter,
|
FixedRecursiveCharacterTextSplitter,
|
||||||
)
|
)
|
||||||
from core.rag.splitter.text_splitter import TextSplitter
|
from core.rag.splitter.text_splitter import TextSplitter
|
||||||
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
|
from models.dataset import Document as DatasetDocument
|
||||||
|
|
||||||
|
|
||||||
class BaseIndexProcessor(ABC):
|
class BaseIndexProcessor(ABC):
|
||||||
@ -37,6 +38,10 @@ class BaseIndexProcessor(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def retrieve(
|
def retrieve(
|
||||||
|
|||||||
@ -131,7 +131,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||||
paragraph = GeneralStructureChunk(**chunks)
|
paragraph = GeneralStructureChunk(**chunks)
|
||||||
documents = []
|
documents = []
|
||||||
for content in paragraph.general_chunk:
|
for content in paragraph.general_chunks:
|
||||||
metadata = {
|
metadata = {
|
||||||
"dataset_id": dataset.id,
|
"dataset_id": dataset.id,
|
||||||
"document_id": document.id,
|
"document_id": document.id,
|
||||||
@ -151,3 +151,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
elif dataset.indexing_technique == "economy":
|
elif dataset.indexing_technique == "economy":
|
||||||
keyword = Keyword(dataset)
|
keyword = Keyword(dataset)
|
||||||
keyword.add_texts(documents)
|
keyword.add_texts(documents)
|
||||||
|
|
||||||
|
|
||||||
|
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
paragraph = GeneralStructureChunk(**chunks)
|
||||||
|
preview = []
|
||||||
|
for content in paragraph.general_chunks:
|
||||||
|
preview.append({"content": content})
|
||||||
|
return {
|
||||||
|
"preview": preview,
|
||||||
|
"total_segments": len(paragraph.general_chunks)
|
||||||
|
}
|
||||||
@ -234,3 +234,19 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
vector.create(documents)
|
vector.create(documents)
|
||||||
|
|
||||||
|
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
parent_childs = ParentChildStructureChunk(**chunks)
|
||||||
|
preview = []
|
||||||
|
for parent_child in parent_childs.parent_child_chunks:
|
||||||
|
preview.append(
|
||||||
|
{
|
||||||
|
"content": parent_child.parent_content,
|
||||||
|
"child_chunks": parent_child.child_contents
|
||||||
|
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"preview": preview,
|
||||||
|
"total_segments": len(parent_childs.parent_child_chunks)
|
||||||
|
}
|
||||||
@ -4,7 +4,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
@ -20,7 +20,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
|||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset, Document as DatasetDocument
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||||
|
|
||||||
|
|
||||||
@ -160,6 +160,12 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
return {"preview": chunks}
|
||||||
|
|
||||||
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||||
format_documents = []
|
format_documents = []
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class GeneralStructureChunk(BaseModel):
|
|||||||
General Structure Chunk.
|
General Structure Chunk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
general_chunk: list[str]
|
general_chunks: list[str]
|
||||||
|
|
||||||
|
|
||||||
class ParentChildChunk(BaseModel):
|
class ParentChildChunk(BaseModel):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
|
||||||
@ -93,3 +93,20 @@ class FileVariable(FileSegment, Variable):
|
|||||||
|
|
||||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class RAGPipelineVariable(BaseModel):
|
||||||
|
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||||
|
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||||
|
label: str = Field(description="label")
|
||||||
|
description: str | None = Field(description="description", default="")
|
||||||
|
variable: str = Field(description="variable key", default="")
|
||||||
|
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
|
||||||
|
default_value: str | None = Field(description="default value", default="")
|
||||||
|
placeholder: str | None = Field(description="placeholder", default="")
|
||||||
|
unit: str | None = Field(description="unit, applicable to Number", default="")
|
||||||
|
tooltips: str | None = Field(description="helpful text", default="")
|
||||||
|
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
|
||||||
|
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
|
||||||
|
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
|
||||||
|
required: bool = Field(description="optional, default false", default=False)
|
||||||
|
options: list[str] | None = Field(default_factory=list)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||||
PIPELINE_VARIABLE_NODE_ID = "pipeline"
|
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||||
|
|||||||
@ -10,7 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable
|
|||||||
from core.variables.segments import FileSegment, NoneSegment
|
from core.variables.segments import FileSegment, NoneSegment
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
|
||||||
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from ..constants import (
|
||||||
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
|
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||||
|
RAG_PIPELINE_VARIABLE_NODE_ID,
|
||||||
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
)
|
||||||
from ..enums import SystemVariableKey
|
from ..enums import SystemVariableKey
|
||||||
|
|
||||||
VariableValue = Union[str, int, float, dict, list, File]
|
VariableValue = Union[str, int, float, dict, list, File]
|
||||||
@ -42,6 +47,10 @@ class VariablePool(BaseModel):
|
|||||||
description="Conversation variables.",
|
description="Conversation variables.",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
rag_pipeline_variables: Mapping[str, Any] = Field(
|
||||||
|
description="RAG pipeline variables.",
|
||||||
|
default_factory=dict,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -50,18 +59,21 @@ class VariablePool(BaseModel):
|
|||||||
user_inputs: Mapping[str, Any] | None = None,
|
user_inputs: Mapping[str, Any] | None = None,
|
||||||
environment_variables: Sequence[Variable] | None = None,
|
environment_variables: Sequence[Variable] | None = None,
|
||||||
conversation_variables: Sequence[Variable] | None = None,
|
conversation_variables: Sequence[Variable] | None = None,
|
||||||
|
rag_pipeline_variables: Mapping[str, Any] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
environment_variables = environment_variables or []
|
environment_variables = environment_variables or []
|
||||||
conversation_variables = conversation_variables or []
|
conversation_variables = conversation_variables or []
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
system_variables = system_variables or {}
|
system_variables = system_variables or {}
|
||||||
|
rag_pipeline_variables = rag_pipeline_variables or {}
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
system_variables=system_variables,
|
system_variables=system_variables,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
environment_variables=environment_variables,
|
environment_variables=environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
conversation_variables=conversation_variables,
|
||||||
|
rag_pipeline_variables=rag_pipeline_variables,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -73,6 +85,9 @@ class VariablePool(BaseModel):
|
|||||||
# Add conversation variables to the variable pool
|
# Add conversation variables to the variable pool
|
||||||
for var in self.conversation_variables:
|
for var in self.conversation_variables:
|
||||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||||
|
# Add rag pipeline variables to the variable pool
|
||||||
|
for var, value in self.rag_pipeline_variables.items():
|
||||||
|
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value)
|
||||||
|
|
||||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class WorkflowType(StrEnum):
|
|||||||
|
|
||||||
WORKFLOW = "workflow"
|
WORKFLOW = "workflow"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
|
RAG_PIPELINE = "rag-pipeline"
|
||||||
|
|
||||||
|
|
||||||
class WorkflowExecutionStatus(StrEnum):
|
class WorkflowExecutionStatus(StrEnum):
|
||||||
|
|||||||
@ -173,7 +173,7 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
elif isinstance(item, NodeRunSucceededEvent):
|
elif isinstance(item, NodeRunSucceededEvent):
|
||||||
if item.node_type == NodeType.END:
|
if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
|
||||||
self.graph_runtime_state.outputs = (
|
self.graph_runtime_state.outputs = (
|
||||||
dict(item.route_node_state.node_run_result.outputs)
|
dict(item.route_node_state.node_run_result.outputs)
|
||||||
if item.route_node_state.node_run_result
|
if item.route_node_state.node_run_result
|
||||||
@ -319,7 +319,7 @@ class GraphEngine:
|
|||||||
# It may not be necessary, but it is necessary. :)
|
# It may not be necessary, but it is necessary. :)
|
||||||
if (
|
if (
|
||||||
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
|
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
|
||||||
== NodeType.END.value
|
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import (
|
|||||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.variables.segments import ArrayAnySegment
|
from core.variables.segments import ArrayAnySegment, FileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import UploadFile
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
from .entities import DatasourceNodeData
|
from .entities import DatasourceNodeData
|
||||||
@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
provider_id=node_data.provider_id,
|
provider_id=node_data.provider_id,
|
||||||
datasource_name=node_data.datasource_name or "",
|
datasource_name=node_data.datasource_name or "",
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
datasource_type=DatasourceProviderType(datasource_type),
|
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||||
)
|
)
|
||||||
except DatasourceNodeError as e:
|
except DatasourceNodeError as e:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
error=f"Failed to get datasource runtime: {str(e)}",
|
error=f"Failed to get datasource runtime: {str(e)}",
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# get parameters
|
# get parameters
|
||||||
datasource_parameters = datasource_runtime.entity.parameters
|
datasource_parameters = datasource_runtime.entity.parameters
|
||||||
@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
"datasource_type": datasource_type,
|
"datasource_type": datasource_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
case DatasourceProviderType.LOCAL_FILE:
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
|
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
|
||||||
|
if not upload_file:
|
||||||
|
raise ValueError("Invalid upload file Info")
|
||||||
|
|
||||||
|
file_info = File(
|
||||||
|
id=upload_file.id,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension="." + upload_file.extension,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
type=datasource_info.get("type", ""),
|
||||||
|
transfer_method=datasource_info.get("transfer_method", ""),
|
||||||
|
remote_url=upload_file.source_url,
|
||||||
|
related_id=upload_file.id,
|
||||||
|
size=upload_file.size,
|
||||||
|
storage_key=upload_file.key,
|
||||||
|
)
|
||||||
|
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
|
||||||
|
for key, value in datasource_info.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = ["file", key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
|
||||||
|
)
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
outputs={
|
outputs={
|
||||||
"file": datasource_info,
|
"file_info": file_info,
|
||||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
"datasource_type": datasource_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
case _:
|
case _:
|
||||||
raise DatasourceNodeError(
|
raise DatasourceNodeError(
|
||||||
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
|
f"Unsupported datasource provider: {datasource_type}"
|
||||||
)
|
)
|
||||||
except PluginDaemonClientSideError as e:
|
except PluginDaemonClientSideError as e:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||||
return list(variable.value) if variable else []
|
return list(variable.value) if variable else []
|
||||||
|
|
||||||
|
|
||||||
|
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||||
|
"""
|
||||||
|
Append variables recursively
|
||||||
|
:param node_id: node id
|
||||||
|
:param variable_key_list: variable key list
|
||||||
|
:param variable_value: variable value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
variable_pool.add([node_id] + variable_key_list, variable_value)
|
||||||
|
|
||||||
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
if isinstance(variable_value, dict):
|
||||||
|
for key, value in variable_value.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = variable_key_list + [key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
|
|||||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||||
class DatasourceInput(BaseModel):
|
class DatasourceInput(BaseModel):
|
||||||
# TODO: check this type
|
# TODO: check this type
|
||||||
value: Optional[Union[Any, list[str]]] = None
|
value: Union[Any, list[str]]
|
||||||
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
||||||
|
|
||||||
@field_validator("type", mode="before")
|
@field_validator("type", mode="before")
|
||||||
|
|||||||
@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
|||||||
def _run(self) -> NodeRunResult: # type: ignore
|
def _run(self) -> NodeRunResult: # type: ignore
|
||||||
node_data = cast(KnowledgeIndexNodeData, self.node_data)
|
node_data = cast(KnowledgeIndexNodeData, self.node_data)
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||||
|
if not dataset_id:
|
||||||
|
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||||
|
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||||
|
if not dataset:
|
||||||
|
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||||
|
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||||
is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER
|
if not variable:
|
||||||
|
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||||
|
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||||
|
if invoke_from:
|
||||||
|
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
|
||||||
|
else:
|
||||||
|
is_preview = False
|
||||||
chunks = variable.value
|
chunks = variable.value
|
||||||
variables = {"chunks": chunks}
|
variables = {"chunks": chunks}
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
||||||
)
|
)
|
||||||
|
outputs = self._get_preview_output(dataset.chunk_structure, chunks)
|
||||||
|
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
try:
|
try:
|
||||||
if is_preview:
|
if is_preview:
|
||||||
@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
|||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
inputs=variables,
|
inputs=variables,
|
||||||
process_data=None,
|
process_data=None,
|
||||||
outputs={"result": "success"},
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
|
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
|
||||||
outputs = {"result": results}
|
variable_pool=variable_pool)
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
|
||||||
)
|
)
|
||||||
|
|
||||||
except KnowledgeIndexNodeError as e:
|
except KnowledgeIndexNodeError as e:
|
||||||
@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _invoke_knowledge_index(
|
def _invoke_knowledge_index(
|
||||||
self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
|
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
|
||||||
|
variable_pool: VariablePool
|
||||||
) -> Any:
|
) -> Any:
|
||||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
|
||||||
if not dataset_id:
|
|
||||||
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
|
||||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||||
if not document_id:
|
if not document_id:
|
||||||
raise KnowledgeIndexNodeError("Document ID is required.")
|
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||||
if not batch:
|
if not batch:
|
||||||
raise KnowledgeIndexNodeError("Batch is required.")
|
raise KnowledgeIndexNodeError("Batch is required.")
|
||||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||||
if not dataset:
|
|
||||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
|
|
||||||
|
|
||||||
document = db.session.query(Document).filter_by(id=document_id).first()
|
|
||||||
if not document:
|
if not document:
|
||||||
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
|
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||||
|
|
||||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||||
index_processor.index(dataset, document, chunks)
|
index_processor.index(dataset, document, chunks)
|
||||||
@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
|||||||
# update document status
|
# update document status
|
||||||
document.indexing_status = "completed"
|
document.indexing_status = "completed"
|
||||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"dataset_id": dataset.id,
|
"dataset_id": dataset.id,
|
||||||
"dataset_name": dataset.name,
|
"dataset_name": dataset.name,
|
||||||
"batch": batch,
|
"batch": batch.value,
|
||||||
"document_id": document.id,
|
"document_id": document.id,
|
||||||
"document_name": document.name,
|
"document_name": document.name,
|
||||||
"created_at": document.created_at,
|
"created_at": document.created_at.timestamp(),
|
||||||
"display_status": document.indexing_status,
|
"display_status": document.indexing_status,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||||
|
return index_processor.format_preview(chunks)
|
||||||
|
|||||||
@ -41,10 +41,9 @@ conversation_variable_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pipeline_variable_fields = {
|
pipeline_variable_fields = {
|
||||||
"id": fields.String,
|
|
||||||
"label": fields.String,
|
"label": fields.String,
|
||||||
"variable": fields.String,
|
"variable": fields.String,
|
||||||
"type": fields.String(attribute="type.value"),
|
"type": fields.String,
|
||||||
"belong_to_node_id": fields.String,
|
"belong_to_node_id": fields.String,
|
||||||
"max_length": fields.Integer,
|
"max_length": fields.Integer,
|
||||||
"required": fields.Boolean,
|
"required": fields.Boolean,
|
||||||
|
|||||||
@ -14,6 +14,8 @@ class UserFrom(StrEnum):
|
|||||||
class WorkflowRunTriggeredFrom(StrEnum):
|
class WorkflowRunTriggeredFrom(StrEnum):
|
||||||
DEBUGGING = "debugging"
|
DEBUGGING = "debugging"
|
||||||
APP_RUN = "app-run"
|
APP_RUN = "app-run"
|
||||||
|
RAG_PIPELINE_RUN = "rag-pipeline-run"
|
||||||
|
RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging"
|
||||||
|
|
||||||
|
|
||||||
class DraftVariableType(StrEnum):
|
class DraftVariableType(StrEnum):
|
||||||
|
|||||||
@ -152,6 +152,7 @@ class Workflow(Base):
|
|||||||
created_by: str,
|
created_by: str,
|
||||||
environment_variables: Sequence[Variable],
|
environment_variables: Sequence[Variable],
|
||||||
conversation_variables: Sequence[Variable],
|
conversation_variables: Sequence[Variable],
|
||||||
|
rag_pipeline_variables: list[dict],
|
||||||
marked_name: str = "",
|
marked_name: str = "",
|
||||||
marked_comment: str = "",
|
marked_comment: str = "",
|
||||||
) -> "Workflow":
|
) -> "Workflow":
|
||||||
@ -166,6 +167,7 @@ class Workflow(Base):
|
|||||||
workflow.created_by = created_by
|
workflow.created_by = created_by
|
||||||
workflow.environment_variables = environment_variables or []
|
workflow.environment_variables = environment_variables or []
|
||||||
workflow.conversation_variables = conversation_variables or []
|
workflow.conversation_variables = conversation_variables or []
|
||||||
|
workflow.rag_pipeline_variables = rag_pipeline_variables or []
|
||||||
workflow.marked_name = marked_name
|
workflow.marked_name = marked_name
|
||||||
workflow.marked_comment = marked_comment
|
workflow.marked_comment = marked_comment
|
||||||
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
@ -340,7 +342,7 @@ class Workflow(Base):
|
|||||||
"features": self.features_dict,
|
"features": self.features_dict,
|
||||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||||
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
||||||
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
|
"rag_pipeline_variables": self.rag_pipeline_variables,
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -553,6 +555,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
|
|||||||
|
|
||||||
SINGLE_STEP = "single-step"
|
SINGLE_STEP = "single-step"
|
||||||
WORKFLOW_RUN = "workflow-run"
|
WORKFLOW_RUN = "workflow-run"
|
||||||
|
RAG_PIPELINE_RUN = "rag-pipeline-run"
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecutionStatus(StrEnum):
|
class WorkflowNodeExecutionStatus(StrEnum):
|
||||||
|
|||||||
@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||||||
RetrievalModel,
|
RetrievalModel,
|
||||||
SegmentUpdateArgs,
|
SegmentUpdateArgs,
|
||||||
)
|
)
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||||
|
KnowledgeBaseUpdateConfiguration,
|
||||||
|
RagPipelineDatasetCreateEntity,
|
||||||
|
)
|
||||||
from services.errors.account import InvalidActionError, NoPermissionError
|
from services.errors.account import InvalidActionError, NoPermissionError
|
||||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||||
from services.errors.dataset import DatasetNameDuplicateError
|
from services.errors.dataset import DatasetNameDuplicateError
|
||||||
@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
|
|||||||
from services.errors.file import FileNotExistsError
|
from services.errors.file import FileNotExistsError
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
from services.feature_service import FeatureModel, FeatureService
|
from services.feature_service import FeatureModel, FeatureService
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
|
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
from services.vector_service import VectorService
|
from services.vector_service import VectorService
|
||||||
from tasks.batch_clean_document_task import batch_clean_document_task
|
from tasks.batch_clean_document_task import batch_clean_document_task
|
||||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||||
|
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
|
||||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||||
@ -278,47 +281,6 @@ class DatasetService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_rag_pipeline_dataset(
|
|
||||||
tenant_id: str,
|
|
||||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
|
||||||
):
|
|
||||||
# check if dataset name already exists
|
|
||||||
if (
|
|
||||||
db.session.query(Dataset)
|
|
||||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
|
||||||
.first()
|
|
||||||
):
|
|
||||||
raise DatasetNameDuplicateError(
|
|
||||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = Dataset(
|
|
||||||
name=rag_pipeline_dataset_create_entity.name,
|
|
||||||
description=rag_pipeline_dataset_create_entity.description,
|
|
||||||
permission=rag_pipeline_dataset_create_entity.permission,
|
|
||||||
provider="vendor",
|
|
||||||
runtime_mode="rag-pipeline",
|
|
||||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
|
||||||
)
|
|
||||||
with Session(db.engine) as session:
|
|
||||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
|
||||||
account = cast(Account, current_user)
|
|
||||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
|
||||||
account=account,
|
|
||||||
import_mode=ImportMode.YAML_CONTENT.value,
|
|
||||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
|
||||||
dataset=dataset,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"id": rag_pipeline_import_info.id,
|
|
||||||
"dataset_id": dataset.id,
|
|
||||||
"pipeline_id": rag_pipeline_import_info.pipeline_id,
|
|
||||||
"status": rag_pipeline_import_info.status,
|
|
||||||
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
|
|
||||||
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
|
|
||||||
"error": rag_pipeline_import_info.error,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||||
@ -529,6 +491,130 @@ class DatasetService:
|
|||||||
if action:
|
if action:
|
||||||
deal_dataset_vector_index_task.delay(dataset_id, action)
|
deal_dataset_vector_index_task.delay(dataset_id, action)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_rag_pipeline_dataset_settings(session: Session,
|
||||||
|
dataset: Dataset,
|
||||||
|
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
||||||
|
has_published: bool = False):
|
||||||
|
if not has_published:
|
||||||
|
dataset.chunk_structure = knowledge_base_setting.chunk_structure
|
||||||
|
index_method = knowledge_base_setting.index_method
|
||||||
|
dataset.indexing_technique = index_method.indexing_technique
|
||||||
|
if index_method == "high_quality":
|
||||||
|
model_manager = ModelManager()
|
||||||
|
embedding_model = model_manager.get_model_instance(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
provider=index_method.embedding_setting.embedding_provider_name,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=index_method.embedding_setting.embedding_model_name,
|
||||||
|
)
|
||||||
|
dataset.embedding_model = embedding_model.model
|
||||||
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
|
embedding_model.provider, embedding_model.model
|
||||||
|
)
|
||||||
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
elif index_method == "economy":
|
||||||
|
dataset.keyword_number = index_method.economy_setting.keyword_number
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid index method")
|
||||||
|
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||||
|
session.add(dataset)
|
||||||
|
else:
|
||||||
|
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
|
||||||
|
raise ValueError("Chunk structure is not allowed to be updated.")
|
||||||
|
action = None
|
||||||
|
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
|
||||||
|
# if update indexing_technique
|
||||||
|
if knowledge_base_setting.index_method.indexing_technique == "economy":
|
||||||
|
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
|
||||||
|
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
|
||||||
|
action = "add"
|
||||||
|
# get embedding model setting
|
||||||
|
try:
|
||||||
|
model_manager = ModelManager()
|
||||||
|
embedding_model = model_manager.get_model_instance(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||||
|
)
|
||||||
|
dataset.embedding_model = embedding_model.model
|
||||||
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
|
embedding_model.provider, embedding_model.model
|
||||||
|
)
|
||||||
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ValueError(
|
||||||
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ValueError(ex.description)
|
||||||
|
else:
|
||||||
|
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||||
|
# Skip embedding model checks if not provided in the update request
|
||||||
|
if dataset.indexing_technique == "high_quality":
|
||||||
|
skip_embedding_update = False
|
||||||
|
try:
|
||||||
|
# Handle existing model provider
|
||||||
|
plugin_model_provider = dataset.embedding_model_provider
|
||||||
|
plugin_model_provider_str = None
|
||||||
|
if plugin_model_provider:
|
||||||
|
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
|
||||||
|
|
||||||
|
# Handle new model provider from request
|
||||||
|
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
|
||||||
|
new_plugin_model_provider_str = None
|
||||||
|
if new_plugin_model_provider:
|
||||||
|
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
|
||||||
|
|
||||||
|
# Only update embedding model if both values are provided and different from current
|
||||||
|
if (
|
||||||
|
plugin_model_provider_str != new_plugin_model_provider_str
|
||||||
|
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
|
||||||
|
):
|
||||||
|
action = "update"
|
||||||
|
model_manager = ModelManager()
|
||||||
|
try:
|
||||||
|
embedding_model = model_manager.get_model_instance(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
# If we can't get the embedding model, skip updating it
|
||||||
|
# and keep the existing settings if available
|
||||||
|
# Skip the rest of the embedding model update
|
||||||
|
skip_embedding_update = True
|
||||||
|
if not skip_embedding_update:
|
||||||
|
dataset.embedding_model = embedding_model.model
|
||||||
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
|
dataset_collection_binding = (
|
||||||
|
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
|
embedding_model.provider, embedding_model.model
|
||||||
|
)
|
||||||
|
)
|
||||||
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ValueError(
|
||||||
|
"No Embedding Model available. Please configure a valid provider "
|
||||||
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ValueError(ex.description)
|
||||||
|
elif dataset.indexing_technique == "economy":
|
||||||
|
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
|
||||||
|
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
|
||||||
|
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
|
||||||
|
session.add(dataset)
|
||||||
|
session.commit()
|
||||||
|
if action:
|
||||||
|
deal_dataset_index_update_task.delay(dataset.id, action)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_dataset(dataset_id, user):
|
def delete_dataset(dataset_id, user):
|
||||||
|
|||||||
@ -4,29 +4,12 @@ from typing import Optional
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core import datasource
|
|
||||||
from core.datasource.__base import datasource_provider
|
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
|
||||||
from core.model_runtime.entities.provider_entities import FormType
|
from core.model_runtime.entities.provider_entities import FormType
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
|
||||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
from core.provider_manager import ProviderManager
|
from extensions.ext_database import db
|
||||||
from models.oauth import DatasourceProvider
|
from models.oauth import DatasourceProvider
|
||||||
from models.provider import ProviderType
|
|
||||||
from services.entities.model_provider_entities import (
|
|
||||||
CustomConfigurationResponse,
|
|
||||||
CustomConfigurationStatus,
|
|
||||||
DefaultModelResponse,
|
|
||||||
ModelWithProviderEntityResponse,
|
|
||||||
ProviderResponse,
|
|
||||||
ProviderWithModelsResponse,
|
|
||||||
SimpleProviderEntityResponse,
|
|
||||||
SystemConfigurationResponse,
|
|
||||||
)
|
|
||||||
from extensions.database import db
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -115,16 +98,26 @@ class DatasourceProviderService:
|
|||||||
|
|
||||||
:param tenant_id: workspace id
|
:param tenant_id: workspace id
|
||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
:param datasource_name: datasource name
|
|
||||||
:param plugin_id: plugin id
|
:param plugin_id: plugin id
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# Get all provider configurations of the current workspace
|
# Get all provider configurations of the current workspace
|
||||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
plugin_id=plugin_id).first()
|
plugin_id=plugin_id).first()
|
||||||
|
if not datasource_provider:
|
||||||
|
return None
|
||||||
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||||
|
# Get provider credential secret variables
|
||||||
|
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||||
|
|
||||||
|
# Obfuscate provider credentials
|
||||||
|
copy_credentials = encrypted_credentials.copy()
|
||||||
|
for key, value in copy_credentials.items():
|
||||||
|
if key in credential_secret_variables:
|
||||||
|
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||||
|
|
||||||
|
return copy_credentials
|
||||||
|
|
||||||
|
|
||||||
def remove_datasource_credentials(self,
|
def remove_datasource_credentials(self,
|
||||||
@ -136,11 +129,9 @@ class DatasourceProviderService:
|
|||||||
|
|
||||||
:param tenant_id: workspace id
|
:param tenant_id: workspace id
|
||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
:param datasource_name: datasource name
|
|
||||||
:param plugin_id: plugin id
|
:param plugin_id: plugin id
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# Get all provider configurations of the current workspace
|
|
||||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
plugin_id=plugin_id).first()
|
plugin_id=plugin_id).first()
|
||||||
|
|||||||
@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel):
|
|||||||
chunk_structure: str
|
chunk_structure: str
|
||||||
index_method: IndexMethod
|
index_method: IndexMethod
|
||||||
retrieval_setting: RetrievalSetting
|
retrieval_setting: RetrievalSetting
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseUpdateConfiguration(BaseModel):
|
||||||
|
"""
|
||||||
|
Knowledge Base Update Configuration.
|
||||||
|
"""
|
||||||
|
index_method: IndexMethod
|
||||||
|
chunk_structure: str
|
||||||
|
retrieval_setting: RetrievalSetting
|
||||||
@ -69,9 +69,9 @@ class PipelineGenerateService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||||
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
||||||
return WorkflowAppGenerator.convert_to_event_stream(
|
return PipelineGenerator.convert_to_event_stream(
|
||||||
WorkflowAppGenerator().single_loop_generate(
|
PipelineGenerator().single_loop_generate(
|
||||||
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||||||
|
|
||||||
recommended_pipelines_results = []
|
recommended_pipelines_results = []
|
||||||
for pipeline_built_in_template in pipeline_built_in_templates:
|
for pipeline_built_in_template in pipeline_built_in_templates:
|
||||||
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
|
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
|
||||||
|
if not pipeline_model:
|
||||||
|
continue
|
||||||
|
|
||||||
recommended_pipeline_result = {
|
recommended_pipeline_result = {
|
||||||
"id": pipeline_built_in_template.id,
|
"id": pipeline_built_in_template.id,
|
||||||
@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||||||
"privacy_policy": pipeline_built_in_template.privacy_policy,
|
"privacy_policy": pipeline_built_in_template.privacy_policy,
|
||||||
"position": pipeline_built_in_template.position,
|
"position": pipeline_built_in_template.position,
|
||||||
}
|
}
|
||||||
dataset: Dataset = pipeline_model.dataset
|
dataset: Dataset | None = pipeline_model.dataset
|
||||||
if dataset:
|
if dataset:
|
||||||
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
|
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
|
||||||
recommended_pipelines_results.append(recommended_pipeline_result)
|
recommended_pipelines_results.append(recommended_pipeline_result)
|
||||||
@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||||||
if not pipeline_template:
|
if not pipeline_template:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# get app detail
|
# get pipeline detail
|
||||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
|
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
|
||||||
if not pipeline or not pipeline.is_public:
|
if not pipeline or not pipeline.is_public:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
dataset: Dataset | None = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": pipeline.id,
|
"id": pipeline.id,
|
||||||
"name": pipeline.name,
|
"name": pipeline.name,
|
||||||
"icon": pipeline.icon,
|
"icon": pipeline_template.icon,
|
||||||
"mode": pipeline.mode,
|
"chunk_structure": dataset.chunk_structure,
|
||||||
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
|
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -46,7 +46,8 @@ from models.workflow import (
|
|||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
from services.dataset_service import DatasetService
|
||||||
|
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||||
|
|
||||||
@ -261,8 +262,7 @@ class RagPipelineService:
|
|||||||
session: Session,
|
session: Session,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
account: Account,
|
account: Account,
|
||||||
marked_name: str = "",
|
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
|
||||||
marked_comment: str = "",
|
|
||||||
) -> Workflow:
|
) -> Workflow:
|
||||||
draft_workflow_stmt = select(Workflow).where(
|
draft_workflow_stmt = select(Workflow).where(
|
||||||
Workflow.tenant_id == pipeline.tenant_id,
|
Workflow.tenant_id == pipeline.tenant_id,
|
||||||
@ -282,18 +282,25 @@ class RagPipelineService:
|
|||||||
graph=draft_workflow.graph,
|
graph=draft_workflow.graph,
|
||||||
features=draft_workflow.features,
|
features=draft_workflow.features,
|
||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
environment_variables=draft_workflow.environment_variables,
|
environment_variables=draft_workflow.environment_variables,
|
||||||
conversation_variables=draft_workflow.conversation_variables,
|
conversation_variables=draft_workflow.conversation_variables,
|
||||||
marked_name=marked_name,
|
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
|
||||||
marked_comment=marked_comment,
|
marked_name="",
|
||||||
|
marked_comment="",
|
||||||
)
|
)
|
||||||
|
|
||||||
# commit db session changes
|
# commit db session changes
|
||||||
session.add(workflow)
|
session.add(workflow)
|
||||||
|
|
||||||
# trigger app workflow events TODO
|
# update dataset
|
||||||
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Dataset not found")
|
||||||
|
DatasetService.update_rag_pipeline_dataset_settings(
|
||||||
|
session=session,
|
||||||
|
dataset=dataset,
|
||||||
|
knowledge_base_setting=knowledge_base_setting,
|
||||||
|
has_published=pipeline.is_published
|
||||||
|
)
|
||||||
# return new workflow
|
# return new workflow
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|||||||
@ -4,13 +4,14 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
from Crypto.Cipher import AES
|
from Crypto.Cipher import AES
|
||||||
from Crypto.Util.Padding import pad, unpad
|
from Crypto.Util.Padding import pad, unpad
|
||||||
|
from flask_login import current_user
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -31,7 +32,10 @@ from factories import variable_factory
|
|||||||
from models import Account
|
from models import Account
|
||||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
|
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||||
|
KnowledgeConfiguration,
|
||||||
|
RagPipelineDatasetCreateEntity,
|
||||||
|
)
|
||||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
@ -540,9 +544,6 @@ class RagPipelineDslService:
|
|||||||
# Update existing pipeline
|
# Update existing pipeline
|
||||||
pipeline.name = pipeline_data.get("name", pipeline.name)
|
pipeline.name = pipeline_data.get("name", pipeline.name)
|
||||||
pipeline.description = pipeline_data.get("description", pipeline.description)
|
pipeline.description = pipeline_data.get("description", pipeline.description)
|
||||||
pipeline.icon_type = icon_type
|
|
||||||
pipeline.icon = icon
|
|
||||||
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
|
|
||||||
pipeline.updated_by = account.id
|
pipeline.updated_by = account.id
|
||||||
else:
|
else:
|
||||||
if account.current_tenant_id is None:
|
if account.current_tenant_id is None:
|
||||||
@ -554,12 +555,6 @@ class RagPipelineDslService:
|
|||||||
pipeline.tenant_id = account.current_tenant_id
|
pipeline.tenant_id = account.current_tenant_id
|
||||||
pipeline.name = pipeline_data.get("name", "")
|
pipeline.name = pipeline_data.get("name", "")
|
||||||
pipeline.description = pipeline_data.get("description", "")
|
pipeline.description = pipeline_data.get("description", "")
|
||||||
pipeline.icon_type = icon_type
|
|
||||||
pipeline.icon = icon
|
|
||||||
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
|
|
||||||
pipeline.enable_site = True
|
|
||||||
pipeline.enable_api = True
|
|
||||||
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
|
|
||||||
pipeline.created_by = account.id
|
pipeline.created_by = account.id
|
||||||
pipeline.updated_by = account.id
|
pipeline.updated_by = account.id
|
||||||
|
|
||||||
@ -674,26 +669,6 @@ class RagPipelineDslService:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
|
|
||||||
"""
|
|
||||||
Append model config export data
|
|
||||||
:param export_data: export data
|
|
||||||
:param pipeline: Pipeline instance
|
|
||||||
"""
|
|
||||||
app_model_config = pipeline.app_model_config
|
|
||||||
if not app_model_config:
|
|
||||||
raise ValueError("Missing app configuration, please check.")
|
|
||||||
|
|
||||||
export_data["model_config"] = app_model_config.to_dict()
|
|
||||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
|
||||||
export_data["dependencies"] = [
|
|
||||||
jsonable_encoder(d.model_dump())
|
|
||||||
for d in DependenciesAnalysisService.generate_dependencies(
|
|
||||||
tenant_id=pipeline.tenant_id, dependencies=dependencies
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@ -863,3 +838,46 @@ class RagPipelineDslService:
|
|||||||
return pt.decode()
|
return pt.decode()
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_rag_pipeline_dataset(
|
||||||
|
tenant_id: str,
|
||||||
|
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||||
|
):
|
||||||
|
# check if dataset name already exists
|
||||||
|
if (
|
||||||
|
db.session.query(Dataset)
|
||||||
|
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||||
|
.first()
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = Dataset(
|
||||||
|
name=rag_pipeline_dataset_create_entity.name,
|
||||||
|
description=rag_pipeline_dataset_create_entity.description,
|
||||||
|
permission=rag_pipeline_dataset_create_entity.permission,
|
||||||
|
provider="vendor",
|
||||||
|
runtime_mode="rag-pipeline",
|
||||||
|
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||||
|
)
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
|
account = cast(Account, current_user)
|
||||||
|
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||||
|
account=account,
|
||||||
|
import_mode=ImportMode.YAML_CONTENT.value,
|
||||||
|
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||||
|
dataset=dataset,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"id": rag_pipeline_import_info.id,
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"pipeline_id": rag_pipeline_import_info.pipeline_id,
|
||||||
|
"status": rag_pipeline_import_info.status,
|
||||||
|
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
|
||||||
|
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
|
||||||
|
"error": rag_pipeline_import_info.error,
|
||||||
|
}
|
||||||
|
|||||||
171
api/tasks/deal_dataset_index_update_task.py
Normal file
171
api/tasks/deal_dataset_index_update_task.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import click
|
||||||
|
from celery import shared_task # type: ignore
|
||||||
|
|
||||||
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
|
from core.rag.models.document import ChildDocument, Document
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset, DocumentSegment
|
||||||
|
from models.dataset import Document as DatasetDocument
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue="dataset")
|
||||||
|
def deal_dataset_index_update_task(dataset_id: str, action: str):
|
||||||
|
"""
|
||||||
|
Async deal dataset from index
|
||||||
|
:param dataset_id: dataset_id
|
||||||
|
:param action: action
|
||||||
|
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
|
||||||
|
"""
|
||||||
|
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
|
||||||
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise Exception("Dataset not found")
|
||||||
|
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
||||||
|
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||||
|
if action == "upgrade":
|
||||||
|
dataset_documents = (
|
||||||
|
db.session.query(DatasetDocument)
|
||||||
|
.filter(
|
||||||
|
DatasetDocument.dataset_id == dataset_id,
|
||||||
|
DatasetDocument.indexing_status == "completed",
|
||||||
|
DatasetDocument.enabled == True,
|
||||||
|
DatasetDocument.archived == False,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if dataset_documents:
|
||||||
|
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||||
|
{"indexing_status": "indexing"}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
for dataset_document in dataset_documents:
|
||||||
|
try:
|
||||||
|
# add from vector index
|
||||||
|
segments = (
|
||||||
|
db.session.query(DocumentSegment)
|
||||||
|
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||||
|
.order_by(DocumentSegment.position.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if segments:
|
||||||
|
documents = []
|
||||||
|
for segment in segments:
|
||||||
|
document = Document(
|
||||||
|
page_content=segment.content,
|
||||||
|
metadata={
|
||||||
|
"doc_id": segment.index_node_id,
|
||||||
|
"doc_hash": segment.index_node_hash,
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
documents.append(document)
|
||||||
|
# save vector index
|
||||||
|
# clean keywords
|
||||||
|
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
|
||||||
|
index_processor.load(dataset, documents, with_keywords=False)
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
|
||||||
|
{"indexing_status": "completed"}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
|
||||||
|
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
elif action == "update":
|
||||||
|
dataset_documents = (
|
||||||
|
db.session.query(DatasetDocument)
|
||||||
|
.filter(
|
||||||
|
DatasetDocument.dataset_id == dataset_id,
|
||||||
|
DatasetDocument.indexing_status == "completed",
|
||||||
|
DatasetDocument.enabled == True,
|
||||||
|
DatasetDocument.archived == False,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# add new index
|
||||||
|
if dataset_documents:
|
||||||
|
# update document status
|
||||||
|
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||||
|
{"indexing_status": "indexing"}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# clean index
|
||||||
|
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||||
|
|
||||||
|
for dataset_document in dataset_documents:
|
||||||
|
# update from vector index
|
||||||
|
try:
|
||||||
|
segments = (
|
||||||
|
db.session.query(DocumentSegment)
|
||||||
|
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||||
|
.order_by(DocumentSegment.position.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if segments:
|
||||||
|
documents = []
|
||||||
|
for segment in segments:
|
||||||
|
document = Document(
|
||||||
|
page_content=segment.content,
|
||||||
|
metadata={
|
||||||
|
"doc_id": segment.index_node_id,
|
||||||
|
"doc_hash": segment.index_node_hash,
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||||
|
child_chunks = segment.get_child_chunks()
|
||||||
|
if child_chunks:
|
||||||
|
child_documents = []
|
||||||
|
for child_chunk in child_chunks:
|
||||||
|
child_document = ChildDocument(
|
||||||
|
page_content=child_chunk.content,
|
||||||
|
metadata={
|
||||||
|
"doc_id": child_chunk.index_node_id,
|
||||||
|
"doc_hash": child_chunk.index_node_hash,
|
||||||
|
"document_id": segment.document_id,
|
||||||
|
"dataset_id": segment.dataset_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
child_documents.append(child_document)
|
||||||
|
document.children = child_documents
|
||||||
|
documents.append(document)
|
||||||
|
# save vector index
|
||||||
|
index_processor.load(dataset, documents, with_keywords=False)
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
|
||||||
|
{"indexing_status": "completed"}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
|
||||||
|
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
# clean collection
|
||||||
|
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||||
|
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
logging.info(
|
||||||
|
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Deal dataset vector index failed")
|
||||||
|
finally:
|
||||||
|
db.session.close()
|
||||||
Loading…
Reference in New Issue
Block a user