mirror of https://github.com/langgenius/dify.git
synced 2026-04-11 03:56:55 +08:00

commit b82b26bba5 ("r2"), parent 9bafd3a226
@@ -39,9 +39,9 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
-from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
+from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
@@ -170,7 +170,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
        args = parser.parse_args()

        try:
-            response = AppGenerateService.generate_single_iteration(
+            response = PipelineGenerateService.generate_single_iteration(
                pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
            )
@@ -207,7 +207,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
        args = parser.parse_args()

        try:
-            response = AppGenerateService.generate_single_loop(
+            response = PipelineGenerateService.generate_single_loop(
                pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
            )
@@ -241,11 +241,12 @@ class DraftRagPipelineRunApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("files", type=list, required=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("datasource_info", type=list, required=True, location="json")
        args = parser.parse_args()

        try:
-            response = AppGenerateService.generate(
+            response = PipelineGenerateService.generate(
                pipeline=pipeline,
                user=current_user,
                args=args,
@@ -258,7 +259,73 @@ class DraftRagPipelineRunApi(Resource):
            raise InvokeRateLimitHttpError(ex.description)


class PublishedRagPipelineRunApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline):
        """
        Run published workflow
        """
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()

        if not isinstance(current_user, Account):
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("datasource_info", type=list, required=True, location="json")
        args = parser.parse_args()

        try:
            response = PipelineGenerateService.generate(
                pipeline=pipeline,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.PUBLISHED,
                streaming=True,
            )

            return helper.compact_generate_response(response)
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)


class RagPipelineDatasourceNodeRunApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Run rag pipeline datasource
        """
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()

        if not isinstance(current_user, Account):
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        inputs = args.get("inputs")

        rag_pipeline_service = RagPipelineService()
        result = rag_pipeline_service.run_datasource_workflow_node(
            pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
        )

        return result


class RagPipelinePublishedNodeRunApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
@@ -283,7 +350,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
            raise ValueError("missing inputs")

        rag_pipeline_service = RagPipelineService()
-        workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node(
+        workflow_node_execution = rag_pipeline_service.run_published_workflow_node(
            pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
        )
@@ -354,7 +421,8 @@ class PublishedRagPipelineApi(Resource):
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()

        if not pipeline.is_published:
            return None
        # fetch published workflow by pipeline
        rag_pipeline_service = RagPipelineService()
        workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
@@ -397,10 +465,8 @@ class PublishedRagPipelineApi(Resource):
                marked_name=args.marked_name or "",
                marked_comment=args.marked_comment or "",
            )

            pipeline.is_published = True
            pipeline.workflow_id = workflow.id
            db.session.commit()

            workflow_created_at = TimestampField().format(workflow.created_at)

            session.commit()
@@ -617,7 +683,7 @@ class RagPipelineByIdApi(Resource):
        return None, 204


-class RagPipelineSecondStepApi(Resource):
+class PublishedRagPipelineSecondStepApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
@@ -632,9 +698,28 @@ class RagPipelineSecondStepApi(Resource):
        node_id = request.args.get("node_id", required=True, type=str)

        rag_pipeline_service = RagPipelineService()
-        variables = rag_pipeline_service.get_second_step_parameters(
-            pipeline=pipeline, node_id=node_id
-        )
+        variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
        return {
            "variables": variables,
        }


class DraftRagPipelineSecondStepApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def get(self, pipeline: Pipeline):
        """
        Get second step parameters of rag pipeline
        """
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
        node_id = request.args.get("node_id", required=True, type=str)

        rag_pipeline_service = RagPipelineService()
        variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
        return {
            "variables": variables,
        }
@@ -732,15 +817,21 @@ api.add_resource(
    RagPipelineDraftNodeRunApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
# api.add_resource(
#     RagPipelinePublishedNodeRunApi,
#     "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
# )
api.add_resource(
    RagPipelineDatasourceNodeRunApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/datasource/nodes/<string:node_id>/run",
)

api.add_resource(
    RagPipelineDraftRunIterationNodeApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)

api.add_resource(
    RagPipelinePublishedNodeRunApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
)

api.add_resource(
    RagPipelineDraftRunLoopNodeApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",

@@ -762,7 +853,6 @@ api.add_resource(
    DefaultRagPipelineBlockConfigApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)

api.add_resource(
    RagPipelineByIdApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",

@@ -784,6 +874,10 @@ api.add_resource(
    "/rag/pipelines/datasource-plugins",
)
api.add_resource(
-    RagPipelineSecondStepApi,
-    "/rag/pipelines/<uuid:pipeline_id>/workflows/processing/paramters",
+    PublishedRagPipelineSecondStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters",
)
+api.add_resource(
+    DraftRagPipelineSecondStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters",
+)
@@ -283,7 +283,7 @@ class AppConfig(BaseModel):
    tenant_id: str
    app_id: str
    app_mode: AppMode
-    additional_features: AppAdditionalFeatures
+    additional_features: Optional[AppAdditionalFeatures] = None
    variables: list[VariableEntity] = []
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
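Note (editorial, not part of the commit): making additional_features optional with a None default changes construction requirements for every AppConfig subclass. A minimal pydantic sketch of the difference, with a made-up Features model standing in for AppAdditionalFeatures:

    # Illustrative only; "Features" is a stand-in model, not Dify code.
    from typing import Optional
    from pydantic import BaseModel

    class Features(BaseModel):
        opening_statement: str = ""

    class ConfigBefore(BaseModel):
        additional_features: Features  # required: ConfigBefore() raises ValidationError

    class ConfigAfter(BaseModel):
        additional_features: Optional[Features] = None  # optional: ConfigAfter() validates

    print(ConfigAfter().additional_features)  # None

This is presumably why PipelineConfig (below) can be built without passing any additional_features.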
api/core/app/apps/pipeline/__init__.py (new file, 0 lines)
api/core/app/apps/pipeline/generate_response_converter.py (new file, 95 lines)
@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
    ErrorStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    PingStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppStreamResponse,
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        return dict(blocking_response.to_dict())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
        :return:
        """
        return cls.convert_blocking_full_response(blocking_response)

    @classmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk

    @classmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk
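Note (editorial, not part of the commit): the converter above is a pure dispatcher — pings collapse to the literal string "ping", everything else is flattened into a dict tagged with the workflow run id. A standalone sketch of that dispatch pattern, with stub classes standing in for Dify's task entities:

    # Stub classes ("Ping", "NodeEvent") are illustrative, not Dify APIs.
    from dataclasses import dataclass
    from typing import Iterator, Union

    @dataclass
    class Ping:
        """Stands in for PingStreamResponse: keep-alive, no payload."""

    @dataclass
    class NodeEvent:
        """Stands in for a node start/finish stream response."""
        event: str
        detail: dict

        def to_dict(self) -> dict:
            return {"event": self.event, **self.detail}

    def convert_stream(chunks: Iterator[Union[Ping, NodeEvent]], run_id: str) -> Iterator[Union[str, dict]]:
        # Mirrors convert_stream_full_response: pings become "ping",
        # everything else is flattened and tagged with the run id.
        for chunk in chunks:
            if isinstance(chunk, Ping):
                yield "ping"
                continue
            yield {"workflow_run_id": run_id, **chunk.to_dict()}

    if __name__ == "__main__":
        events = iter([Ping(), NodeEvent("node_started", {"node_id": "n1"})])
        for out in convert_stream(events, run_id="run-123"):
            print(out)  # "ping", then the flattened node_started dict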
api/core/app/apps/pipeline/pipeline_config_manager.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow


class PipelineConfig(WorkflowUIBasedAppConfig):
    """
    Pipeline Config Entity.
    """

    pass


class PipelineConfigManager(BaseAppConfigManager):
    @classmethod
    def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
        pipeline_config = PipelineConfig(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            app_mode=AppMode.RAG_PIPELINE,
            workflow_id=workflow.id,
            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
        )

        return pipeline_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
        """
        Validate for pipeline config

        :param tenant_id: tenant id
        :param config: app model config args
        :param only_structure_validate: only validate the structure of the config
        """
        related_config_keys = []

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
api/core/app/apps/pipeline/pipeline_generator.py (new file, 496 lines)
@@ -0,0 +1,496 @@
import contextvars
import datetime
import json
import logging
import random
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from services.dataset_service import DocumentService

logger = logging.getLogger(__name__)


class PipelineGenerator(BaseAppGenerator):
    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Generator[Mapping | str, None, None]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline,
            workflow=workflow,
        )

        inputs: Mapping[str, Any] = args["inputs"]
        datasource_type: str = args["datasource_type"]
        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

        for datasource_info in datasource_info_list:
            workflow_run_id = str(uuid.uuid4())
            document_id = None
            if invoke_from == InvokeFrom.PUBLISHED:
                position = DocumentService.get_documents_position(pipeline.dataset_id)
                document = self._build_document(
                    tenant_id=pipeline.tenant_id,
                    dataset_id=pipeline.dataset_id,
                    built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
                    datasource_type=datasource_type,
                    datasource_info=datasource_info,
                    created_from="rag-pipeline",
                    position=position,
                    account=user,
                    batch=batch,
                    document_form=pipeline.dataset.doc_form,
                )
                db.session.add(document)
                db.session.commit()
                document_id = document.id
            # init application generate entity
            application_generate_entity = RagPipelineGenerateEntity(
                task_id=str(uuid.uuid4()),
                pipline_config=pipeline_config,
                datasource_type=datasource_type,
                datasource_info=datasource_info,
                dataset_id=pipeline.dataset_id,
                batch=batch,
                document_id=document_id,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs,
                    variables=pipeline_config.variables,
                    tenant_id=pipeline.tenant_id,
                    strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
                ),
                files=[],
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                call_depth=call_depth,
                workflow_run_id=workflow_run_id,
            )

            contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
            contexts.plugin_tool_providers.set({})
            contexts.plugin_tool_providers_lock.set(threading.Lock())

            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

            return self._generate(
                pipeline=pipeline,
                workflow=workflow,
                user=user,
                application_generate_entity=application_generate_entity,
                invoke_from=invoke_from,
                workflow_node_execution_repository=workflow_node_execution_repository,
                streaming=streaming,
                workflow_thread_pool_id=workflow_thread_pool_id,
            )

    def _generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        application_generate_entity: RagPipelineGenerateEntity,
        invoke_from: InvokeFrom,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        streaming: bool = True,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param user: account or end user
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        :param workflow_thread_pool_id: workflow thread pool id
        """
        # init queue manager
        queue_manager = PipelineQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            app_mode=pipeline.mode,
        )

        # new thread
        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "context": contextvars.copy_context(),
                "workflow_thread_pool_id": workflow_thread_pool_id,
            },
        )

        worker_thread.start()

        # return response or stream generator
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            workflow_node_execution_repository=workflow_node_execution_repository,
            stream=streaming,
        )

        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,
        app_model: App,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                node_id=node_id, inputs=args["inputs"]
            ),
            workflow_run_id=str(uuid.uuid4()),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def single_loop_generate(
        self,
        app_model: App,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

        # init application generate entity
        application_generate_entity = WorkflowAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
            workflow_run_id=str(uuid.uuid4()),
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """
        for var, val in context.items():
            var.set(val)
        with flask_app.app_context():
            try:
                # workflow app
                runner = PipelineRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

                runner.run()
            except GenerateTaskStoppedError:
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.close()

    def _handle_response(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param user: account or end user
        :param stream: is stream
        :param workflow_node_execution_repository: optional repository for workflow node execution
        :return:
        """
        # init generate task pipeline
        generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream,
            workflow_node_execution_repository=workflow_node_execution_repository,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                logger.exception(
                    f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
                )
                raise e

    def _build_document(
        self,
        tenant_id: str,
        dataset_id: str,
        built_in_field_enabled: bool,
        datasource_type: str,
        datasource_info: Mapping[str, Any],
        created_from: str,
        position: int,
        account: Account,
        batch: str,
        document_form: str,
    ):
        if datasource_type == "local_file":
            name = datasource_info["name"]
        elif datasource_type == "online_document":
            name = datasource_info["page_title"]
        elif datasource_type == "website_crawl":
            name = datasource_info["title"]
        else:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")

        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=position,
            data_source_type=datasource_type,
            data_source_info=json.dumps(datasource_info),
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form,
        )
        doc_metadata = {}
        if built_in_field_enabled:
            doc_metadata = {
                BuiltInField.document_name: name,
                BuiltInField.uploader: account.name,
                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.source: datasource_type,
            }
        if doc_metadata:
            document.doc_metadata = doc_metadata
        return document
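Note (editorial, not part of the commit): PipelineGenerator.generate fans out one document and one workflow run per datasource_info entry while sharing a single batch id across the whole request (and, as written above, returns from inside the loop on the first entry). A minimal sketch of the fan-out bookkeeping, using illustrative names rather than Dify APIs:

    # "fan_out" and "make_batch_id" are illustrative helpers, not Dify code.
    import random
    import time
    import uuid

    def make_batch_id() -> str:
        # Same scheme as the generator: timestamp plus a 6-digit random suffix.
        return time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

    def fan_out(datasource_info_list: list[dict]) -> list[dict]:
        batch = make_batch_id()
        runs = []
        for datasource_info in datasource_info_list:
            runs.append(
                {
                    "batch": batch,                        # shared across the request
                    "workflow_run_id": str(uuid.uuid4()),  # unique per datasource entry
                    "datasource_info": datasource_info,
                }
            )
        return runs

    if __name__ == "__main__":
        print(fan_out([{"name": "a.pdf"}, {"name": "b.pdf"}]))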
api/core/app/apps/pipeline/pipeline_queue_manager.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueStopEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowPartialSuccessEvent,
    QueueWorkflowSucceededEvent,
    WorkflowQueueMessage,
)


class PipelineQueueManager(AppQueueManager):
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._app_mode = app_mode

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
        :param event:
        :param pub_from:
        :return:
        """
        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

        self._q.put(message)

        if isinstance(
            event,
            QueueStopEvent
            | QueueErrorEvent
            | QueueMessageEndEvent
            | QueueWorkflowSucceededEvent
            | QueueWorkflowFailedEvent
            | QueueWorkflowPartialSuccessEvent,
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedError()
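Note (editorial, not part of the commit): the terminal-event check above relies on isinstance accepting X | Y union types directly, which Python supports since 3.10. A tiny standalone illustration with made-up event classes:

    # "Stop", "Error", "Progress" are stand-ins for the Queue*Event classes.
    class Stop: ...
    class Error: ...
    class Progress: ...

    def is_terminal(event: object) -> bool:
        # isinstance with a union (PEP 604) checks against each member.
        return isinstance(event, Stop | Error)

    print(is_terminal(Stop()), is_terminal(Progress()))  # True False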
api/core/app/apps/pipeline/pipeline_runner.py (new file, 154 lines)
@@ -0,0 +1,154 @@
import logging
from typing import Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    RagPipelineGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)


class PipelineRunner(WorkflowBasedAppRunner):
    """
    Pipeline Application Runner
    """

    def __init__(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def run(self) -> None:
        """
        Run application
        """
        app_config = self.application_generate_entity.app_config
        app_config = cast(PipelineConfig, app_config)

        user_id = None
        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = self.application_generate_entity.user_id

        pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if dify_config.DEBUG:
            workflow_callbacks.append(WorkflowLoggingCallback())

        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=self.application_generate_entity.single_loop_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
                SystemVariableKey.APP_ID: app_config.app_id,
                SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
                SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
                SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
                SystemVariableKey.BATCH: self.application_generate_entity.batch,
                SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
            }

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
            )

            # init graph
            graph = self._init_graph(graph_config=workflow.graph_dict)

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id,
        )

        generator = workflow_entry.run(callbacks=workflow_callbacks)

        for event in generator:
            self._handle_event(workflow_entry, event)

    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow
@@ -21,6 +21,7 @@ class InvokeFrom(Enum):
    WEB_APP = "web-app"
    EXPLORE = "explore"
    DEBUGGER = "debugger"
+    PUBLISHED = "published"

    @classmethod
    def value_of(cls, value: str):
@@ -226,3 +227,37 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
        inputs: dict

    single_loop_run: Optional[SingleLoopRunEntity] = None


class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
    """
    RAG Pipeline Application Generate Entity.
    """

    # app config
    pipline_config: WorkflowUIBasedAppConfig
    datasource_type: str
    datasource_info: Mapping[str, Any]
    dataset_id: str
    batch: str
    document_id: str

    class SingleIterationRunEntity(BaseModel):
        """
        Single Iteration Run Entity.
        """

        node_id: str
        inputs: dict

    single_iteration_run: Optional[SingleIterationRunEntity] = None

    class SingleLoopRunEntity(BaseModel):
        """
        Single Loop Run Entity.
        """

        node_id: str
        inputs: dict

    single_loop_run: Optional[SingleLoopRunEntity] = None
@@ -1,18 +1,13 @@
-from collections.abc import Mapping
-from typing import Any
+from abc import ABC, abstractmethod

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)
-from core.plugin.impl.datasource import PluginDatasourceManager
-from core.plugin.utils.converter import convert_parameters_to_plugin_format


-class DatasourcePlugin:
-    tenant_id: str
-    icon: str
-    plugin_unique_identifier: str
+class DatasourcePlugin(ABC):
    entity: DatasourceEntity
    runtime: DatasourceRuntime
@@ -20,57 +15,19 @@ class DatasourcePlugin:
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
-        tenant_id: str,
-        icon: str,
-        plugin_unique_identifier: str,
    ) -> None:
        self.entity = entity
        self.runtime = runtime
-        self.tenant_id = tenant_id
-        self.icon = icon
-        self.plugin_unique_identifier = plugin_unique_identifier
-
-    def _invoke_first_step(
-        self,
-        user_id: str,
-        datasource_parameters: dict[str, Any],
-    ) -> Mapping[str, Any]:
-        manager = PluginDatasourceManager()
-
-        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
-
-        return manager.invoke_first_step(
-            tenant_id=self.tenant_id,
-            user_id=user_id,
-            datasource_provider=self.entity.identity.provider,
-            datasource_name=self.entity.identity.name,
-            credentials=self.runtime.credentials,
-            datasource_parameters=datasource_parameters,
-        )
-
-    def _invoke_second_step(
-        self,
-        user_id: str,
-        datasource_parameters: dict[str, Any],
-    ) -> Mapping[str, Any]:
-        manager = PluginDatasourceManager()
-
-        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
-
-        return manager.invoke_second_step(
-            tenant_id=self.tenant_id,
-            user_id=user_id,
-            datasource_provider=self.entity.identity.provider,
-            datasource_name=self.entity.identity.name,
-            credentials=self.runtime.credentials,
-            datasource_parameters=datasource_parameters,
-        )
+    @abstractmethod
+    def datasource_provider_type(self) -> DatasourceProviderType:
+        """
+        returns the type of the datasource provider
+        """
+        return DatasourceProviderType.LOCAL_FILE

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
-        return DatasourcePlugin(
-            entity=self.entity,
+        return self.__class__(
+            entity=self.entity.model_copy(),
            runtime=runtime,
            tenant_id=self.tenant_id,
            icon=self.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
@@ -1,26 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
-from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
+from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError


-class DatasourcePluginProviderController:
+class DatasourcePluginProviderController(ABC):
    entity: DatasourceProviderEntityWithPlugin
-    tenant_id: str
-    plugin_id: str
-    plugin_unique_identifier: str

-    def __init__(
-        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
-    ) -> None:
+    def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
        self.entity = entity
-        self.tenant_id = tenant_id
-        self.plugin_id = plugin_id
-        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def need_credentials(self) -> bool:
@@ -44,29 +37,19 @@ class DatasourcePluginProviderController:
        ):
            raise ToolProviderCredentialValidationError("Invalid credentials")

-    def get_datasource(self, datasource_name: str) -> DatasourcePlugin:  # type: ignore
+    @property
+    def provider_type(self) -> DatasourceProviderType:
+        """
+        returns the type of the provider
+        """
+        return DatasourceProviderType.LOCAL_FILE
+
+    @abstractmethod
+    def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
        """
        return datasource with given name
        """
-        datasource_entity = next(
-            (
-                datasource_entity
-                for datasource_entity in self.entity.datasources
-                if datasource_entity.identity.name == datasource_name
-            ),
-            None,
-        )
-
-        if not datasource_entity:
-            raise ValueError(f"Datasource with name {datasource_name} not found")
-
-        return DatasourcePlugin(
-            entity=datasource_entity,
-            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
-            tenant_id=self.tenant_id,
-            icon=self.entity.identity.icon,
-            plugin_unique_identifier=self.plugin_unique_identifier,
-        )
+        pass

    def get_datasources(self) -> list[DatasourcePlugin]:  # type: ignore
        """
@@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel):
    description: I18nObject
    icon: str | dict
    label: I18nObject  # label
-    type: ToolProviderType
+    type: str
    masked_credentials: Optional[dict] = None
    original_credentials: Optional[dict] = None
    is_team_authorization: bool = False
    allow_delete: bool = True
-    plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
-    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
+    plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
+    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
    datasources: list[DatasourceApiEntity] = Field(default_factory=list)
    labels: list[str] = Field(default_factory=list)
@@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum):

    ONLINE_DOCUMENT = "online_document"
    LOCAL_FILE = "local_file"
-    WEBSITE = "website"
+    WEBSITE_CRAWL = "website_crawl"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
@@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter):


class DatasourceIdentity(BaseModel):
-    author: str = Field(..., description="The author of the tool")
-    name: str = Field(..., description="The name of the tool")
-    label: I18nObject = Field(..., description="The label of the tool")
-    provider: str = Field(..., description="The provider of the tool")
+    author: str = Field(..., description="The author of the datasource")
+    name: str = Field(..., description="The name of the datasource")
+    label: I18nObject = Field(..., description="The label of the datasource")
+    provider: str = Field(..., description="The provider of the datasource")
    icon: Optional[str] = None
@@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):


class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
-    datasources: list[DatasourceEntity] = Field(default_factory=list)
+    datasources: list[DatasourceEntity] = Field(default_factory=list)


class DatasourceInvokeMeta(BaseModel):
@@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum):
    """

    RAG_PIPELINE = "rag_pipeline"


class GetOnlineDocumentPagesRequest(BaseModel):
    """
    Get online document pages request
    """

    tenant_id: str = Field(..., description="The tenant id")


class OnlineDocumentPageIcon(BaseModel):
    """
    Online document page icon
    """

    type: str = Field(..., description="The type of the icon")
    url: str = Field(..., description="The url of the icon")


class OnlineDocumentPage(BaseModel):
    """
    Online document page
    """

    page_id: str = Field(..., description="The page id")
    page_title: str = Field(..., description="The page title")
    page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
    last_edited_time: str = Field(..., description="The last edited time")


class OnlineDocumentInfo(BaseModel):
    """
    Online document info
    """

    workspace_id: str = Field(..., description="The workspace id")
    workspace_name: str = Field(..., description="The workspace name")
    workspace_icon: str = Field(..., description="The workspace icon")
    total: int = Field(..., description="The total number of documents")
    pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")


class GetOnlineDocumentPagesResponse(BaseModel):
    """
    Get online document pages response
    """

    result: list[OnlineDocumentInfo]


class GetOnlineDocumentPageContentRequest(BaseModel):
    """
    Get online document page content request
    """

    online_document_info_list: list[OnlineDocumentInfo]


class OnlineDocumentPageContent(BaseModel):
    """
    Online document page content
    """

    page_id: str = Field(..., description="The page id")
    content: str = Field(..., description="The content of the page")


class GetOnlineDocumentPageContentResponse(BaseModel):
    """
    Get online document page content response
    """

    result: list[OnlineDocumentPageContent]


class GetWebsiteCrawlRequest(BaseModel):
    """
    Get website crawl request
    """

    url: str = Field(..., description="The url of the website")
    crawl_parameters: dict = Field(..., description="The crawl parameters")


class WebSiteInfo(BaseModel):
    """
    Website info
    """

    source_url: str = Field(..., description="The url of the website")
    markdown: str = Field(..., description="The markdown of the website")
    title: str = Field(..., description="The title of the website")
    description: str = Field(..., description="The description of the website")


class GetWebsiteCrawlResponse(BaseModel):
    """
    Get website crawl response
    """

    result: list[WebSiteInfo]
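Note (editorial, not part of the commit): these pydantic models give plugin responses a validated shape. A hypothetical round-trip with made-up field values, assuming pydantic v2's model_validate; the two models are redeclared locally so the snippet stands alone:

    # Self-contained illustration; values are invented.
    from pydantic import BaseModel, Field

    class WebSiteInfo(BaseModel):
        source_url: str = Field(..., description="The url of the website")
        markdown: str = Field(..., description="The markdown of the website")
        title: str = Field(..., description="The title of the website")
        description: str = Field(..., description="The description of the website")

    class GetWebsiteCrawlResponse(BaseModel):
        result: list[WebSiteInfo]

    raw = {
        "result": [
            {"source_url": "https://example.com", "markdown": "# Hi", "title": "Example", "description": "demo"}
        ]
    }
    resp = GetWebsiteCrawlResponse.model_validate(raw)  # raises ValidationError on bad shapes
    print(resp.result[0].title)  # "Example"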
api/core/datasource/local_file/local_file_plugin.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class LocalFileDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def datasource_provider_type(self) -> DatasourceProviderType:
        return DatasourceProviderType.LOCAL_FILE

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        return DatasourcePlugin(
            entity=self.entity,
            runtime=runtime,
            tenant_id=self.tenant_id,
            icon=self.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
api/core/datasource/local_file/local_file_provider.py (new file, 58 lines)
@ -0,0 +1,58 @@
|
||||
from typing import Any

from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin


class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    tenant_id: str
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity)
        self.tenant_id = tenant_id
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.LOCAL_FILE

    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider
        """
        pass

    def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return LocalFileDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
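Usage sketch (assumes `controller` is a constructed LocalFileDatasourcePluginProviderController; "upload_file" matches the datasource name in the hard-coded local file provider entry further below):

    # Hypothetical caller code, not part of the diff.
    datasource = controller.get_datasource("upload_file")
    assert datasource.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE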
80 api/core/datasource/online_document/online_document_plugin.py Normal file
@ -0,0 +1,80 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    GetOnlineDocumentPageContentResponse,
    GetOnlineDocumentPagesRequest,
    GetOnlineDocumentPagesResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def _get_online_document_pages(
        self,
        user_id: str,
        datasource_parameters: GetOnlineDocumentPagesRequest,
        provider_type: str,
    ) -> GetOnlineDocumentPagesResponse:
        manager = PluginDatasourceManager()

        return manager.get_online_document_pages(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def _get_online_document_page_content(
        self,
        user_id: str,
        datasource_parameters: GetOnlineDocumentPageContentRequest,
        provider_type: str,
    ) -> GetOnlineDocumentPageContentResponse:
        manager = PluginDatasourceManager()

        return manager.get_online_document_page_content(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> DatasourceProviderType:
        return DatasourceProviderType.ONLINE_DOCUMENT

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        # Instantiate the concrete subclass: the base constructor takes only
        # (entity, runtime), so forwarding the metadata keywords to
        # DatasourcePlugin would fail at runtime.
        return OnlineDocumentDatasourcePlugin(
            entity=self.entity,
            runtime=runtime,
            tenant_id=self.tenant_id,
            icon=self.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
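Consumption sketch for the page-content flow (assumes `content_response` came back from _get_online_document_page_content above; fields as declared in the entities earlier in this diff):

    # Hypothetical caller code, not part of the diff.
    for page in content_response.result:
        print(page.page_id, len(page.content))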
@ -0,0 +1,50 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin


class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    tenant_id: str
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity)
        self.tenant_id = tenant_id
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.ONLINE_DOCUMENT

    def get_datasource(self, datasource_name: str) -> DatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        # Return the concrete plugin type (the base constructor does not take
        # tenant/icon metadata).
        return OnlineDocumentDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
63 api/core/datasource/website_crawl/website_crawl_plugin.py Normal file
@ -0,0 +1,63 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
    GetWebsiteCrawlRequest,
    GetWebsiteCrawlResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def _get_website_crawl(
        self,
        user_id: str,
        datasource_parameters: GetWebsiteCrawlRequest,
        provider_type: str,
    ) -> GetWebsiteCrawlResponse:
        manager = PluginDatasourceManager()

        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

        # The manager method is renamed from invoke_first_step to
        # get_website_crawl elsewhere in this change; call it by the new name.
        return manager.get_website_crawl(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> DatasourceProviderType:
        return DatasourceProviderType.WEBSITE_CRAWL

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        # Instantiate the concrete subclass (the base constructor takes only
        # entity and runtime).
        return WebsiteCrawlDatasourcePlugin(
            entity=self.entity,
            runtime=runtime,
            tenant_id=self.tenant_id,
            icon=self.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
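End-to-end sketch for the crawl path (assumes `plugin` is a resolved WebsiteCrawlDatasourcePlugin with credentials on its runtime; the provider_type string is whatever the node configuration carries):

    # Hypothetical caller code, not part of the diff.
    crawl_response = plugin._get_website_crawl(
        user_id=account.id,
        datasource_parameters=GetWebsiteCrawlRequest(
            url="https://docs.example.com",
            crawl_parameters={"limit": 5},  # parameter keys are plugin-defined
        ),
        provider_type="website_crawl",
    )
    for site in crawl_response.result:
        print(site.source_url, site.title)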
50 api/core/datasource/website_crawl/website_crawl_provider.py Normal file
@ -0,0 +1,50 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin


class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    tenant_id: str
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity)
        self.tenant_id = tenant_id
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.WEBSITE_CRAWL

    def get_datasource(self, datasource_name: str) -> DatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        # Return the concrete plugin type (the base constructor does not take
        # tenant/icon metadata).
        return WebsiteCrawlDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel):
    provider: str
    plugin_unique_identifier: str
    plugin_id: str
    author: str
    declaration: DatasourceProviderEntityWithPlugin
@ -1,6 +1,14 @@
from collections.abc import Mapping
from typing import Any

from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import (
    GetOnlineDocumentPageContentRequest,
    GetOnlineDocumentPageContentResponse,
    GetOnlineDocumentPagesRequest,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlRequest,
    GetWebsiteCrawlResponse,
)
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import (
    PluginBasicBooleanResponse,
@ -10,7 +18,7 @@ from core.plugin.impl.base import BasePluginClient


class PluginDatasourceManager(BasePluginClient):
    def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
    def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]:
        """
        Fetch datasource providers for the given tenant.
        """
@ -19,27 +27,27 @@ class PluginDatasourceManager(BasePluginClient):
            for provider in json_response.get("data", []):
                declaration = provider.get("declaration", {}) or {}
                provider_name = declaration.get("identity", {}).get("name")
                for tool in declaration.get("tools", []):
                    tool["identity"]["provider"] = provider_name
                for datasource in declaration.get("datasources", []):
                    datasource["identity"]["provider"] = provider_name

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/datasources",
            list[PluginDatasourceProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )
        # response = self._request_with_plugin_daemon_response(
        #     "GET",
        #     f"plugin/{tenant_id}/management/datasources",
        #     list[PluginDatasourceProviderEntity],
        #     params={"page": 1, "page_size": 256},
        #     transformer=transformer,
        # )

        for provider in response:
            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
        # for provider in response:
        #     provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

            # override the provider name for each tool to plugin_id/provider_name
            for datasource in provider.declaration.datasources:
                datasource.identity.provider = provider.declaration.identity.name
            # # override the provider name for each tool to plugin_id/provider_name
            # for datasource in provider.declaration.datasources:
            #     datasource.identity.provider = provider.declaration.identity.name

        return response
        return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())]
    def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
        """
@ -71,15 +79,16 @@ class PluginDatasourceManager(BasePluginClient):

        return response

    def invoke_first_step(
    def get_website_crawl(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: dict[str, Any],
    ) -> Mapping[str, Any]:
        datasource_parameters: GetWebsiteCrawlRequest,
        provider_type: str,
    ) -> GetWebsiteCrawlResponse:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """
@ -88,8 +97,8 @@ class PluginDatasourceManager(BasePluginClient):

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/first_step",
            dict,
            f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl",
            GetWebsiteCrawlResponse,
            data={
                "user_id": user_id,
                "data": {
@ -109,15 +118,16 @@ class PluginDatasourceManager(BasePluginClient):

        raise Exception("No response from plugin daemon")
    def invoke_second_step(
    def get_online_document_pages(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: dict[str, Any],
    ) -> Mapping[str, Any]:
        datasource_parameters: GetOnlineDocumentPagesRequest,
        provider_type: str,
    ) -> GetOnlineDocumentPagesResponse:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """
@ -126,8 +136,47 @@ class PluginDatasourceManager(BasePluginClient):

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/second_step",
            dict,
            f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages",
            GetOnlineDocumentPagesResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "datasource_parameters": datasource_parameters,
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )
        for resp in response:
            return resp

        raise Exception("No response from plugin daemon")
    def get_online_document_page_content(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: GetOnlineDocumentPageContentRequest,
        provider_type: str,
    ) -> GetOnlineDocumentPageContentResponse:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content",
            GetOnlineDocumentPageContentResponse,
            data={
                "user_id": user_id,
                "data": {
@ -176,3 +225,53 @@ class PluginDatasourceManager(BasePluginClient):
            return resp.result

        return False
    def _get_local_file_datasource_provider(self) -> dict[str, Any]:
        return {
            "id": "langgenius/file/file",
            "author": "langgenius",
            "name": "langgenius/file/file",
            "plugin_id": "langgenius/file",
            "plugin_unique_identifier": "langgenius/file:0.0.1@dify",
            "description": {
                "zh_Hans": "File",
                "en_US": "File",
                "pt_BR": "File",
                "ja_JP": "File",
            },
            "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
            "label": {
                "zh_Hans": "File",
                "en_US": "File",
                "pt_BR": "File",
                "ja_JP": "File",
            },
            "type": "datasource",
            "team_credentials": {},
            "is_team_authorization": False,
            "allow_delete": True,
            "datasources": [
                {
                    "author": "langgenius",
                    "name": "upload_file",
                    "label": {
                        "en_US": "File",
                        "zh_Hans": "File",
                        "pt_BR": "File",
                        "ja_JP": "File",
                    },
                    "description": {
                        "en_US": "File",
                        "zh_Hans": "File",
                        "pt_BR": "File",
                        "ja_JP": "File.",
                    },
                    "parameters": [],
                    "labels": ["search"],
                    "output_schema": None,
                }
            ],
            "labels": ["search"],
        }
@ -14,3 +14,7 @@ class SystemVariableKey(StrEnum):
    APP_ID = "app_id"
    WORKFLOW_ID = "workflow_id"
    WORKFLOW_RUN_ID = "workflow_run_id"
    # RAG Pipeline
    DOCUMENT_ID = "document_id"
    BATCH = "batch"
    DATASET_ID = "dataset_id"
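The three new keys are read from the variable pool the same way the existing system variables are; a sketch (assumes a populated VariablePool, as in the KnowledgeIndexNode changes below):

    # Hypothetical example, not part of the diff.
    dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
    batch = variable_pool.get(["sys", SystemVariableKey.BATCH])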
@ -3,7 +3,11 @@ from typing import Any, cast

from core.datasource.entities.datasource_entities import (
    DatasourceParameter,
    DatasourceProviderType,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
@ -77,15 +81,44 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
            for_log=True,
        )

        # get conversation id
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

        try:
            # TODO: handle result
            result = datasource_runtime._invoke_second_step(
                user_id=self.user_id,
                datasource_parameters=parameters,
            )
            if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                result = datasource_runtime._get_online_document_page_content(
                    user_id=self.user_id,
                    datasource_parameters=parameters,
                    provider_type=node_data.provider_type,
                )
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=parameters_for_log,
                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                    outputs={
                        # result.result is a list of models, so dump each item
                        "result": [item.model_dump() for item in result.result],
                        "datasource_type": datasource_runtime.datasource_provider_type(),
                    },
                )
            elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
                    user_id=self.user_id,
                    datasource_parameters=parameters,
                    provider_type=node_data.provider_type,
                )
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=parameters_for_log,
                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                    outputs={
                        "result": [item.model_dump() for item in result.result],
                        "datasource_type": datasource_runtime.datasource_provider_type(),
                    },
                )
            else:
                raise DatasourceNodeError(
                    f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                )
        except PluginDaemonClientSideError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
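For reference, the success envelope both branches above produce has this shape (values illustrative, mirroring the WebSiteInfo fields):

    # Hypothetical example of the node's outputs, not part of the diff.
    outputs = {
        "result": [
            {
                "source_url": "https://example.com",
                "markdown": "# Example",
                "title": "Example",
                "description": "Example site",
            }
        ],
        "datasource_type": "website_crawl",
    }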
@ -155,9 +155,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
    """

    type: str = "knowledge-index"
    dataset_id: str
    document_id: str
    index_chunk_variable_selector: list[str]
    chunk_structure: Literal["general", "parent-child"]
    index_method: IndexMethod
    retrieval_setting: RetrievalSetting
@ -1,25 +1,19 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast

from flask_login import current_user

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from services.dataset_service import DatasetCollectionBindingService
from services.feature_service import FeatureService

from .entities import KnowledgeIndexNodeData
from .exc import (
@ -43,8 +37,9 @@ class KnowledgeIndexNode(LLMNode):

    def _run(self) -> NodeRunResult:  # type: ignore
        node_data = cast(KnowledgeIndexNodeData, self.node_data)
        variable_pool = self.graph_runtime_state.variable_pool
        # extract variables
        variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
        variable = variable_pool.get(node_data.index_chunk_variable_selector)
        if not isinstance(variable, ObjectSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
@ -57,34 +52,9 @@ class KnowledgeIndexNode(LLMNode):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
            )
        # check rate limit
        if self.tenant_id:
            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
            if knowledge_rate_limit.enabled:
                current_time = int(time.time() * 1000)
                key = f"rate_limit_{self.tenant_id}"
                redis_client.zadd(key, {current_time: current_time})
                redis_client.zremrangebyscore(key, 0, current_time - 60000)
                request_count = redis_client.zcard(key)
                if request_count > knowledge_rate_limit.limit:
                    # add ratelimit record
                    rate_limit_log = RateLimitLog(
                        tenant_id=self.tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    db.session.add(rate_limit_log)
                    db.session.commit()
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        inputs=variables,
                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                        error_type="RateLimitExceeded",
                    )

        # retrieve knowledge
        try:
            results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
            results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
            outputs = {"result": results}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@ -107,54 +77,26 @@ class KnowledgeIndexNode(LLMNode):
                error_type=type(e).__name__,
            )

    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any:
        dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
    def _invoke_knowledge_index(
        self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
    ) -> Any:
        dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
        if not dataset_id:
            raise KnowledgeIndexNodeError("Dataset ID is required.")
        document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        if not document_id:
            raise KnowledgeIndexNodeError("Document ID is required.")
        batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
        if not batch:
            raise KnowledgeIndexNodeError("Batch is required.")
        dataset = Dataset.query.filter_by(id=dataset_id).first()
        if not dataset:
            raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
            raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

        document = Document.query.filter_by(id=node_data.document_id).first()
        document = Document.query.filter_by(id=document_id).first()
        if not document:
            raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.")
            raise KnowledgeIndexNodeError(f"Document {document_id} not found.")

        retrieval_setting = node_data.retrieval_setting
        index_method = node_data.index_method
        if not dataset.indexing_technique:
            if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = index_method.indexing_technique
            if index_method.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if (
                    index_method.embedding_setting.embedding_model
                    and index_method.embedding_setting.embedding_model_provider
                ):
                    dataset_embedding_model = index_method.embedding_setting.embedding_model
                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
                if not dataset.retrieval_model:
                    default_retrieval_model = {
                        "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                        "reranking_enable": False,
                        "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                        "top_k": 2,
                        "score_threshold_enabled": False,
                    }

                    dataset.retrieval_model = (
                        retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
                    )  # type: ignore
        index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor()
        index_processor.index(dataset, document, chunks)

@ -166,6 +108,7 @@ class KnowledgeIndexNode(LLMNode):
        return {
            "dataset_id": dataset.id,
            "dataset_name": dataset.name,
            "batch": batch,
            "document_id": document.id,
            "document_name": document.name,
            "created_at": document.created_at,
@ -1,66 +0,0 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501
@ -57,8 +57,6 @@ class MultipleRetrievalConfig(BaseModel):


class ModelConfig(BaseModel):
    provider: str
    name: str
    mode: str
@ -39,7 +39,6 @@ from core.variables.variables import (
from core.workflow.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    ENVIRONMENT_VARIABLE_NODE_ID,
    PIPELINE_VARIABLE_NODE_ID,
)


@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
    result = result.model_copy(update={"selector": selector})
    return cast(Variable, result)


def build_segment(value: Any, /) -> Segment:
    if value is None:
        return NoneSegment()
@ -0,0 +1,113 @@
"""add_pipeline_info_2

Revision ID: abb18a379e62
Revises: b35c3db83d09
Create Date: 2025-05-16 16:59:16.423127

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'abb18a379e62'
down_revision = 'b35c3db83d09'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('component_failure_stats')
    op.drop_table('reliability_data')
    op.drop_table('maintenance')
    op.drop_table('operational_data')
    op.drop_table('component_failure')
    op.drop_table('tool_providers')
    op.drop_table('safety_data')
    op.drop_table('incident_data')
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.drop_column('mode')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))

    op.create_table('incident_data',
        sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False),
        sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False),
        sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey')
    )
    op.create_table('safety_data',
        sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False),
        sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey')
    )
    op.create_table('tool_providers',
        sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
        sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
        sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
        sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )
    op.create_table('component_failure',
        sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False),
        sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'),
        sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry')
    )
    op.create_table('operational_data',
        sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
        sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey')
    )
    op.create_table('maintenance',
        sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False),
        sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey')
    )
    op.create_table('reliability_data',
        sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
        sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey')
    )
    op.create_table('component_failure_stats',
        sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
        sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey')
    )
    # ### end Alembic commands ###
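To try this revision locally, the usual Flask-Migrate flow should apply (assuming the standard dify api setup; commands are assumptions, not part of the diff):

    flask db upgrade          # apply up to abb18a379e62
    flask db downgrade -1     # revert it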
@ -1170,6 +1170,7 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
    def pipeline(self):
        return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()


class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
    __tablename__ = "pipeline_customized_templates"
    __table_args__ = (
@ -1205,6 +1206,7 @@ class Pipeline(Base):  # type: ignore[name-defined]
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
@ -52,6 +52,7 @@ class AppMode(StrEnum):
    ADVANCED_CHAT = "advanced-chat"
    AGENT_CHAT = "agent-chat"
    CHANNEL = "channel"
    RAG_PIPELINE = "rag-pipeline"

    @classmethod
    def value_of(cls, value: str) -> "AppMode":
@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4

from core.variables import utils as variable_utils
@ -43,7 +43,7 @@ class WorkflowType(Enum):

    WORKFLOW = "workflow"
    CHAT = "chat"
    RAG_PIPELINE = "rag_pipeline"
    RAG_PIPELINE = "rag-pipeline"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
@ -370,7 +370,7 @@ class Workflow(Base):
        return results

    @rag_pipeline_variables.setter
    def rag_pipeline_variables(self, values: List[dict]) -> None:
    def rag_pipeline_variables(self, values: list[dict]) -> None:
        self._rag_pipeline_variables = json.dumps(
            {item["variable"]: item for item in values},
            ensure_ascii=False,
@ -1550,7 +1550,7 @@ class DocumentService:
    @staticmethod
    def build_document(
        dataset: Dataset,
        process_rule_id: str,
        process_rule_id: str | None,
        data_source_type: str,
        document_form: str,
        document_language: str,
109 api/services/rag_pipeline/pipeline_generate_service.py Normal file
@ -0,0 +1,109 @@
from collections.abc import Mapping
from typing import Any, Union

from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService


class PipelineGenerateService:
    @classmethod
    def generate(
        cls,
        pipeline: Pipeline,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ):
        """
        Pipeline Content Generate
        :param pipeline: pipeline
        :param user: user
        :param args: args
        :param invoke_from: invoke from
        :param streaming: streaming
        :return:
        """
        try:
            workflow = cls._get_workflow(pipeline, invoke_from)
            return PipelineGenerator.convert_to_event_stream(
                PipelineGenerator().generate(
                    pipeline=pipeline,
                    workflow=workflow,
                    user=user,
                    args=args,
                    invoke_from=invoke_from,
                    streaming=streaming,
                    call_depth=0,
                    workflow_thread_pool_id=None,
                ),
            )

        except Exception:
            raise

    @staticmethod
    def _get_max_active_requests(app_model: App) -> int:
        max_active_requests = app_model.max_active_requests
        if max_active_requests is None:
            max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
        return max_active_requests

    @classmethod
    def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                AdvancedChatAppGenerator().single_iteration_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                WorkflowAppGenerator().single_iteration_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        else:
            raise ValueError(f"Invalid app mode {app_model.mode}")

    @classmethod
    def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
        workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
        return WorkflowAppGenerator.convert_to_event_stream(
            WorkflowAppGenerator().single_loop_generate(
                app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
            )
        )

    @classmethod
    def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow:
        """
        Get workflow
        :param pipeline: pipeline
        :param invoke_from: invoke from
        :return:
        """
        rag_pipeline_service = RagPipelineService()
        if invoke_from == InvokeFrom.DEBUGGER:
            # fetch draft workflow by app_model
            workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)

            if not workflow:
                raise ValueError("Workflow not initialized")
        else:
            # fetch published workflow by app_model
            workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)

            if not workflow:
                raise ValueError("Workflow not published")

        return workflow
@ -29,32 +29,31 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
        :param language: language
        :return:
        """

        pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
            PipelineBuiltInTemplate.language == language
        ).all()

        pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
            db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
        )

        recommended_pipelines_results = []
        for pipeline_built_in_template in pipeline_built_in_templates:
            pipeline_model: Pipeline = pipeline_built_in_template.pipeline

            recommended_pipeline_result = {
                'id': pipeline_built_in_template.id,
                'name': pipeline_built_in_template.name,
                'pipeline_id': pipeline_model.id,
                'description': pipeline_built_in_template.description,
                'icon': pipeline_built_in_template.icon,
                'copyright': pipeline_built_in_template.copyright,
                'privacy_policy': pipeline_built_in_template.privacy_policy,
                'position': pipeline_built_in_template.position,
                "id": pipeline_built_in_template.id,
                "name": pipeline_built_in_template.name,
                "pipeline_id": pipeline_model.id,
                "description": pipeline_built_in_template.description,
                "icon": pipeline_built_in_template.icon,
                "copyright": pipeline_built_in_template.copyright,
                "privacy_policy": pipeline_built_in_template.privacy_policy,
                "position": pipeline_built_in_template.position,
            }
            dataset: Dataset = pipeline_model.dataset
            if dataset:
                recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
                recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
            recommended_pipelines_results.append(recommended_pipeline_result)

        return {'pipeline_templates': recommended_pipelines_results}

        return {"pipeline_templates": recommended_pipelines_results}

    @classmethod
    def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
@ -64,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
        :return:
        """
        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        # is in public recommended list
        pipeline_template = (
            db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
@ -3,7 +3,7 @@ import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Literal, Optional
from typing import Any, Optional
from uuid import uuid4

from flask_login import current_user
@ -46,7 +46,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
class RagPipelineService:
    @staticmethod
    def get_pipeline_templates(
        type: Literal["built-in", "customized"] = "built-in", language: str = "en-US"
        type: str = "built-in", language: str = "en-US"
    ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
@ -358,11 +358,11 @@ class RagPipelineService:

        return workflow_node_execution

    def run_datasource_workflow_node(
    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow datasource
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
@ -393,6 +393,41 @@ class RagPipelineService:

        return workflow_node_execution

    def run_datasource_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> dict:
        """
        Run published workflow datasource
        """
        # Returns a plain dict (result + provider_type), so the annotation is
        # dict rather than WorkflowNodeExecution.
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")

        # run draft workflow node
        start_at = time.perf_counter()

        datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=datasource_node_data.get("provider_id"),
            datasource_name=datasource_node_data.get("datasource_name"),
            tenant_id=pipeline.tenant_id,
        )
        result = datasource_runtime._invoke_first_step(
            inputs=user_inputs,
            provider_type=datasource_node_data.get("provider_type"),
            user_id=account.id,
        )

        return {
            "result": result,
            "provider_type": datasource_node_data.get("provider_type"),
        }
    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
@ -552,7 +587,7 @@ class RagPipelineService:

        return workflow

    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
@ -567,9 +602,33 @@ class RagPipelineService:
            return []

        # get datasource provider
        datasource_provider_variables = [item for item in rag_pipeline_variables
                                         if item.get("belong_to_node_id") == node_id
                                         or item.get("belong_to_node_id") == "shared"]
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """

        workflow = self.get_draft_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []

        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables
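The node/shared filter in both helpers reduces to plain list filtering; a self-contained sketch of the selection rule:

    # Hypothetical example, not part of the diff.
    rag_pipeline_variables = [
        {"variable": "url", "belong_to_node_id": "node_1"},
        {"variable": "limit", "belong_to_node_id": "shared"},
        {"variable": "token", "belong_to_node_id": "node_2"},
    ]
    node_id = "node_1"
    selected = [
        item
        for item in rag_pipeline_variables
        if item.get("belong_to_node_id") in (node_id, "shared")
    ]
    # -> the "url" and "limit" entries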
    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: