Merge branch 'feat/r2' into deploy/rag-dev

# Conflicts:
#	web/app/components/workflow-app/components/workflow-main.tsx
#	web/app/components/workflow/constants.ts
#	web/app/components/workflow/header/run-and-history.tsx
#	web/app/components/workflow/hooks-store/store.ts
#	web/app/components/workflow/hooks/use-nodes-interactions.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/hooks/use-workflow.ts
#	web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx
#	web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx
#	web/app/components/workflow/nodes/code/use-config.ts
#	web/app/components/workflow/nodes/llm/default.ts
#	web/app/components/workflow/panel/index.tsx
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/app/components/workflow/store/workflow/index.ts
#	web/app/components/workflow/types.ts
#	web/config/index.ts
#	web/types/workflow.ts
This commit is contained in:
jyong 2025-07-03 11:40:54 +08:00
commit 7c5893db91
137 changed files with 11396 additions and 66 deletions

View File

@ -6,7 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "feat/rag-pipeline"
- "deploy/rag-dev"
tags:
- "*"

View File

@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/dev"
- "deploy/rag-dev"
types:
- completed
@ -12,12 +12,13 @@ jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success'
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.SSH_HOST }}
host: ${{ secrets.RAG_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -1,4 +1,3 @@
import os
import sys
@ -17,20 +16,20 @@ else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
from app_factory import create_app

View File

@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
)
class HostedFetchPipelineTemplateConfig(BaseSettings):
"""
Configuration for fetching pipeline templates
"""
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="database",
)
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote pipeline templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
# place the configs in alphabet order
HostedAnthropicConfig,
HostedAzureOpenAiConfig,
HostedFetchAppTemplateConfig,
HostedFetchPipelineTemplateConfig,
HostedMinmaxConfig,
HostedOpenAiConfig,
HostedSparkConfig,

View File

@ -3,6 +3,7 @@ from threading import Lock
from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)

View File

@ -76,7 +76,6 @@ from .billing import billing, compliance
# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
@ -85,6 +84,14 @@ from .datasets import (
metadata,
website,
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (

View File

@ -283,6 +283,15 @@ class DatasetApi(Resource):
location="json",
help="Invalid external knowledge api id.",
)
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
args = parser.parse_args()
data = request.get_json()

View File

@ -1,3 +1,4 @@
import json
import logging
from argparse import ArgumentTypeError
from datetime import UTC, datetime
@ -51,6 +52,7 @@ from fields.document_fields import (
)
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -661,7 +663,7 @@ class DocumentDetailApi(DocumentResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -1028,6 +1030,41 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200
class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
)
if not log:
return {
"datasource_info": None,
"datasource_type": None,
"input_data": None,
"datasource_node_id": None,
}, 200
return {
"datasource_info": json.loads(log.datasource_info),
"datasource_type": log.datasource_type,
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
@ -1050,3 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500
class PipelineNotFoundError(BaseHTTPException):
error_code = "pipeline_not_found"
description = "Pipeline not found."
code = 404

View File

@ -0,0 +1,197 @@
from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.datasource_provider_service import DatasourceProviderService
class DatasourcePluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
redirect_url = (
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
)
system_credentials = plugin_oauth_config.system_credentials
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()
class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
oauth_handler = OAuthHandler()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
)
return {"result": datasources}, 200
class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
try:
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
"/oauth/plugin/datasource",
)
api.add_resource(
DatasourceOauthCallback,
"/oauth/plugin/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource",
)
api.add_resource(
DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<string:auth_id>",
)

View File

@ -0,0 +1,55 @@
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DataSourceContentPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run datasource content preview
"""
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=datasource_type,
is_published=True,
)
return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@ -0,0 +1,162 @@
import logging
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
return {"data": template.yaml_content}, 200
class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@ -0,0 +1,171 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
parser.add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
args = parser.parse_args()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@ -0,0 +1,146 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,43 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
if not pipeline:
raise PipelineNotFoundError()
kwargs["pipeline"] = pipeline
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -113,9 +113,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -128,6 +128,16 @@ class VariableEntity(BaseModel):
return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
@ -285,7 +295,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures
additional_features: Optional[AppAdditionalFeatures] = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

View File

@ -1,4 +1,4 @@
from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow
@ -20,3 +20,19 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
user_input_form = workflow.rag_pipeline_user_input_form()
# variables
for variable in user_input_form:
variables.append(RagPipelineVariableEntity.model_validate(variable))
return variables

View File

@ -43,11 +43,13 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
@ -183,6 +185,14 @@ class WorkflowResponseConverter:
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
node_data = cast(DatasourceNodeData, event.node_data)
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
f"{node_data.plugin_id}/{node_data.provider_name}",
)
response.data.extras["icon"] = provider_entity.declaration.identity.icon
return response

View File

View File

@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@ -0,0 +1,64 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -0,0 +1,621 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
)
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
# run in child thread
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"pipeline": pipeline,
"workflow_id": workflow.id,
"user": user,
"application_generate_entity": application_generate_entity,
"invoke_from": invoke_from,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"streaming": streaming,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info["name"]
elif datasource_type == "online_document":
name = datasource_info["page"]["page_name"]
elif datasource_type == "website_crawl":
name = datasource_info["title"]
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document

View File

@ -0,0 +1,44 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@ -0,0 +1,221 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
SystemVariableKey.BATCH: self.application_generate_entity.batch,
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
}
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
# init graph
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
start_node_id=self.application_generate_entity.start_node_id,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
for event in generator:
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
"""
Init pipeline graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
nodes = graph_config.get("nodes", [])
edges = graph_config.get("edges", [])
real_run_nodes = []
real_edges = []
exclude_node_ids = []
for node in nodes:
node_id = node.get("id")
node_type = node.get("data", {}).get("type", "")
if node_type == "datasource":
if start_node_id != node_id:
exclude_node_ids.append(node_id)
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
return graph

View File

@ -36,6 +36,7 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
@classmethod
def value_of(cls, value: str):
@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: Optional[str] = None
start_node_id: Optional[str] = None
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class SingleLoopRunEntity(BaseModel):
"""
Single Loop Run Entity.
"""
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None

View File

@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):
self.current_loop += 1
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
"""Run on datasource start."""
if dify_config.DEBUG:
print_text(
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
color=self.color,
)
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class DatasourcePlugin(ABC):
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
) -> None:
self.entity = entity
self.runtime = runtime
@abstractmethod
def datasource_provider_type(self) -> str:
"""
returns the type of the datasource provider
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
)

View File

@ -0,0 +1,118 @@
from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_datasource_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@abstractmethod
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
"""
return datasource with given name
"""
pass
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@ -0,0 +1,36 @@
from typing import Any, Optional
from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
class DatasourceRuntime(BaseModel):
"""
Meta data of a datasource call processing
"""
tenant_id: str
datasource_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeDatasourceRuntime(DatasourceRuntime):
"""
Fake datasource runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id",
invoke_from=InvokeFrom.DEBUGGER,
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
credentials={},
runtime_parameters={},
)

View File

View File

@ -0,0 +1,244 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class DatasourceFileManager:
@staticmethod
def sign_file(datasource_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"datasources/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=present_filename,
size=len(file_binary),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(file_binary).hexdigest(),
source_url="",
)
db.session.add(upload_file)
db.session.commit()
db.session.refresh(upload_file)
return upload_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
) -> UploadFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=filename,
size=len(blob),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(blob).hexdigest(),
source_url=file_url,
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == id,
)
.first()
)
if not upload_file:
return None
blob = storage.load_once(upload_file.key)
return blob, upload_file.mime_type
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_upload_file_id(upload_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == upload_file_id,
)
.first()
)
if not upload_file:
return None, None
stream = storage.load_stream(upload_file.key)
return stream, upload_file.mime_type
# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager

View File

@ -0,0 +1,100 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
# check if context is set
try:
contexts.datasource_plugin_providers.get()
except LookupError:
contexts.datasource_plugin_providers.set({})
contexts.datasource_plugin_providers_lock.set(Lock())
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider_id] = controller
return controller
@classmethod
def get_datasource_runtime(
cls,
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:param datasource_name: the name of the datasource
:param tenant_id: the tenant id
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)

View File

@ -0,0 +1,71 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
class DatasourceApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[DatasourceParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("datasources", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"datasources": datasources,
"labels": self.labels,
}

View File

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel, Field
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -0,0 +1,361 @@
import enum
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum):
"""
Enum class for datasource provider
"""
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class DatasourceParameter(PluginParameter):
"""
Overrides type
"""
class DatasourceParameterType(enum.StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: DatasourceParameterType = Field(..., description="The type of the parameter")
description: I18nObject = Field(..., description="The description of the parameter")
@classmethod
def get_simple_instance(
cls,
name: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
type=typ,
required=required,
options=option_objs,
description=I18nObject(en_US="", zh_Hans=""),
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the datasource")
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or []
class DatasourceProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class DatasourceProviderEntity(BaseModel):
"""
Datasource provider entity
"""
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
provider_type: DatasourceProviderType
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
"""
Get an instance of DatasourceInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class DatasourceLabel(BaseModel):
"""
Datasource label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
"""
Enum class for datasource invoke
"""
RAG_PIPELINE = "rag_pipeline"
class OnlineDocumentPage(BaseModel):
"""
Online document page
"""
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
"""
Online document info
"""
workspace_id: str = Field(..., description="The workspace id")
workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class OnlineDocumentPagesMessage(BaseModel):
"""
Get online document pages response
"""
result: list[OnlineDocumentInfo]
class GetOnlineDocumentPageContentRequest(BaseModel):
"""
Get online document page content request
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
type: str = Field(..., description="The type of the page")
class OnlineDocumentPageContent(BaseModel):
"""
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
class GetOnlineDocumentPageContentResponse(BaseModel):
"""
Get online document page content response
"""
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
"""
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
source_url: str = Field(..., description="The url of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
class WebSiteInfo(BaseModel):
"""
Website info
"""
status: Optional[str] = Field(..., description="crawl job status")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
"""
Get website crawl response
"""
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
class DatasourceMessage(ToolInvokeMessage):
pass
#########################
# Online driver file
#########################
class OnlineDriveFile(BaseModel):
"""
Online driver file
"""
key: str = Field(..., description="The key of the file")
size: int = Field(..., description="The size of the file")
class OnlineDriveFileBucket(BaseModel):
"""
Online driver file bucket
"""
bucket: Optional[str] = Field(None, description="The bucket of the file")
files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
is_truncated: bool = Field(False, description="Whether the bucket has more files")
class OnlineDriveBrowseFilesRequest(BaseModel):
"""
Get online driver file list request
"""
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
bucket: Optional[str] = Field(None, description="Storage bucket name")
max_keys: int = Field(20, description="Maximum number of files to return")
start_after: Optional[str] = Field(
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
)
class OnlineDriveBrowseFilesResponse(BaseModel):
"""
Get online driver file list response
"""
result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")
class OnlineDriveDownloadFileRequest(BaseModel):
"""
Get online driver file
"""
key: str = Field(..., description="The name of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")

View File

@ -0,0 +1,37 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
class DatasourceProviderNotFoundError(ValueError):
pass
class DatasourceNotFoundError(ValueError):
pass
class DatasourceParameterValidationError(ValueError):
pass
class DatasourceProviderCredentialValidationError(ValueError):
pass
class DatasourceNotSupportedError(ValueError):
pass
class DatasourceInvokeError(ValueError):
pass
class DatasourceApiSchemaError(ValueError):
pass
class DatasourceEngineInvokeError(Exception):
meta: DatasourceInvokeMeta
def __init__(self, meta, **kwargs):
self.meta = meta
super().__init__(**kwargs)

View File

@ -0,0 +1,28 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class LocalFileDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE

View File

@ -0,0 +1,56 @@
from typing import Any
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return LocalFileDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -0,0 +1,73 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_online_document_pages(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_pages(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def get_online_document_page_content(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_page_content(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -0,0 +1,73 @@
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def online_drive_browse_files(
self,
user_id: str,
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_browse_files(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def online_drive_download_file(
self,
user_id: str,
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_download_file(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DRIVE
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDriveDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

View File

@ -0,0 +1,265 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()

View File

@ -0,0 +1,121 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension
from typing import Optional
from core.datasource.datasource_file_manager import DatasourceFileManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
class DatasourceFileMessageTransformer:
@classmethod
def transform_datasource_invoke_messages(
cls,
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
"""
for message in messages:
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
yield message
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
message.message, DatasourceMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, DatasourceMessage.TextMessage)
file = DatasourceFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("file_name", None)
# if message is str, encode it to bytes
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
else:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@ -0,0 +1,389 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)

View File

@ -0,0 +1,17 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)

View File

@ -0,0 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False

View File

@ -0,0 +1,43 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@ -0,0 +1,35 @@
import logging
from pathlib import Path
from typing import Any
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e

View File

@ -0,0 +1,53 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
WebsiteCrawlMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_website_crawl(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -0,0 +1,52 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str
class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: Optional[dict] = None
name: str
indexing_status: str
error: Optional[str] = None
enabled: bool
class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]

View File

@ -0,0 +1,15 @@
from typing import TYPE_CHECKING, Any, cast
from core.datasource import datasource_file_manager
from core.datasource.datasource_file_manager import DatasourceFileManager
if TYPE_CHECKING:
from core.datasource.datasource_file_manager import DatasourceFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
class DatasourceFileParser:
@staticmethod
def get_datasource_file_manager() -> "DatasourceFileManager":
return cast("DatasourceFileManager", datasource_file_manager["manager"])

View File

@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod
def value_of(value):

View File

@ -135,3 +135,4 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
DATASOURCE_TRACE = "datasource"

View File

@ -0,0 +1,21 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
class OAuthSchema(BaseModel):
"""
OAuth schema
"""
client_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator
from werkzeug.exceptions import NotFound
from core.agent.plugin_entities import AgentStrategyProviderEntity
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum):
Model = "model"
Extension = "extension"
AgentStrategy = "agent-strategy"
Datasource = "datasource"
class PluginDeclaration(BaseModel):
@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
datasources: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -90,6 +93,7 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
datasource: Optional[DatasourceProviderEntity] = None
meta: Meta
@model_validator(mode="before")
@ -100,6 +104,8 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Tool
elif values.get("model"):
values["category"] = PluginCategory.Model
elif values.get("datasource"):
values["category"] = PluginCategory.Datasource
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
else:
@ -193,6 +199,11 @@ class ToolProviderID(GenericProviderID):
self.plugin_name = f"{self.provider_name}_tool"
class DatasourceProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value

View File

@ -6,6 +6,7 @@ from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
@ -48,6 +49,14 @@ class PluginToolProviderEntity(BaseModel):
declaration: ToolProviderEntityWithPlugin
class PluginDatasourceProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
is_authorized: bool = False
declaration: DatasourceProviderEntityWithPlugin
class PluginAgentProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str

View File

@ -0,0 +1,329 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
WebsiteCrawlMessage,
)
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from services.tools.tools_transform_service import ToolTransformService
class PluginDatasourceManager(BasePluginClient):
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
all_response = [local_file_datasource_provider] + response
for provider in all_response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.datasources:
tool.identity.provider = provider.declaration.identity.name
return all_response
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
"""
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):
datasource["identity"]["provider"] = tool_provider_id.provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasource",
PluginDatasourceProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for datasource in response.declaration.datasources:
datasource.identity.provider = response.declaration.identity.name
return response
def get_website_crawl(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
WebsiteCrawlMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def get_online_document_pages(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
OnlineDocumentPagesMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def get_online_document_page_content(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
DatasourceMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"page": datasource_parameters.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def online_drive_browse_files(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
OnlineDriveBrowseFilesResponse,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"request": request.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
yield from response
def online_drive_download_file(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
DatasourceMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"request": request.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
yield from response
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
"""
# datasource_provider_id = GenericProviderID(provider_id)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
return False
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return {
"id": "langgenius/file/file",
"plugin_id": "langgenius/file",
"provider": "file",
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
"declaration": {
"identity": {
"author": "langgenius",
"name": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [
{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
}
],
},
}

View File

@ -4,7 +4,10 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginToolProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -197,6 +200,36 @@ class PluginToolManager(BasePluginClient):
return False
def validate_datasource_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the datasource
"""
tool_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
"provider": tool_provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": tool_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
return False
def get_runtime_parameters(
self,
tenant_id: str,

View File

@ -28,10 +28,12 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for text in texts:
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -49,18 +51,17 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
else:
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -239,7 +240,11 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)

View File

@ -0,0 +1,38 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
class DatasourceStreamEvent(Enum):
"""
Datasource Stream event
"""
PROCESSING = "datasource_processing"
COMPLETED = "datasource_completed"
ERROR = "datasource_error"
class BaseDatasourceEvent(BaseModel):
pass
class DatasourceErrorEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.ERROR.value
error: str = Field(..., description="error message")
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
data: Mapping[str, Any] | list = Field(..., description="result")
total: Optional[int] = Field(default=0, description="total")
completed: Optional[int] = Field(default=0, description="completed")
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
total: Optional[int] = Field(..., description="total")
completed: Optional[int] = Field(..., description="completed")

View File

@ -13,3 +13,5 @@ class MetadataDataSource(Enum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"

View File

@ -1,7 +1,8 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from configs import dify_config
from core.model_manager import ModelInstance
@ -13,6 +14,7 @@ from core.rag.splitter.fixed_text_splitter import (
)
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
class BaseIndexProcessor(ABC):
@ -33,6 +35,14 @@ class BaseIndexProcessor(ABC):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError
@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError
@abstractmethod
def retrieve(
self,

View File

@ -1,19 +1,22 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -127,3 +130,34 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
paragraph = GeneralStructureChunk(**chunks)
documents = []
for content in paragraph.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

View File

@ -1,20 +1,23 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -202,3 +205,40 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
parent_childs = ParentChildStructureChunk(**chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(parent_child.parent_content),
}
child_documents = []
for child in parent_child.child_contents:
child_metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(child),
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

View File

@ -4,7 +4,8 @@ import logging
import re
import threading
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
import pandas as pd
from flask import Flask, current_app
@ -14,13 +15,15 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import Document, QAStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -161,6 +164,36 @@ class QAIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
qa_chunks = QAStructureChunk(**chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(qa_chunk.question),
"answer": qa_chunk.answer,
}
doc = Document(page_content=qa_chunk.question, metadata=metadata)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
else:
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
qa_chunks = QAStructureChunk(**chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)}
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():

View File

@ -35,6 +35,48 @@ class Document(BaseModel):
children: Optional[list[ChildDocument]] = None
class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunks: list[str]
class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
class QAChunk(BaseModel):
"""
QA Chunk.
"""
question: str
answer: str
class QAStructureChunk(BaseModel):
"""
QAStructureChunk.
"""
qa_chunks: list[QAChunk]
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

View File

@ -5,6 +5,7 @@ class RetrievalMethod(Enum):
SEMANTIC_SEARCH = "semantic_search"
FULL_TEXT_SEARCH = "full_text_search"
HYBRID_SEARCH = "hybrid_search"
KEYWORD_SEARCH = "keyword_search"
@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:

View File

@ -262,6 +262,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
@ -283,7 +284,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
)
if self._app_id:
@ -317,6 +318,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
@ -334,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of NodeExecution instances
"""
# Get the database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)
# Convert database models to domain models
domain_models = []

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import cast
from typing import Any, cast
from uuid import uuid4
from pydantic import Field
from pydantic import BaseModel, Field
from core.helper import encrypter
@ -93,3 +93,32 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(
description="max length, applicable to text-input, paragraph, and file-list", default=0
)
default_value: Any = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(
description="image, document, audio, video, custom.", default_factory=list
)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(
description="remote_url, local_file, tool_file.", default_factory=list
)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
class RAGPipelineVariableInput(BaseModel):
variable: RAGPipelineVariable
value: Any

View File

@ -1,3 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -9,7 +9,13 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.variables.variables import RAGPipelineVariableInput
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import SystemVariableKey
from factories import variable_factory
@ -44,6 +50,10 @@ class VariablePool(BaseModel):
description="Conversation variables.",
default_factory=list,
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
@ -54,6 +64,9 @@ class VariablePool(BaseModel):
# Add conversation variables to the variable pool
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
for var in self.rag_pipeline_variables:
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""

View File

@ -20,6 +20,7 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):

View File

@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
DATASOURCE_INFO = "datasource_info"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"

View File

@ -14,3 +14,10 @@ class SystemVariableKey(StrEnum):
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
BATCH = "batch"
DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"
DATASOURCE_INFO = "datasource_info"
INVOKE_FROM = "invoke_from"

View File

@ -121,6 +121,7 @@ class Graph(BaseModel):
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
@ -141,6 +142,7 @@ class Graph(BaseModel):
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
),
None,
)

View File

@ -175,7 +175,7 @@ class GraphEngine:
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
self.graph_runtime_state.outputs = (
dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result
@ -320,10 +320,10 @@ class GraphEngine:
raise e
# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
== NodeType.END.value
):
if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
NodeType.END.value,
NodeType.KNOWLEDGE_INDEX.value,
]:
break
previous_route_node_state = route_node_state

View File

@ -0,0 +1,3 @@
from .datasource_node import DatasourceNode
__all__ = ["DatasourceNode"]

View File

@ -0,0 +1,468 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from services.datasource_provider_service import DatasourceProviderService
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(BaseNode[DatasourceNodeData]):
"""
Datasource Node
"""
_node_data_cls = DatasourceNodeData
_node_type = NodeType.DATASOURCE
def _run(self) -> Generator:
"""
Run the datasource node
"""
node_data = cast(DatasourceNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
if not datasource_type:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = datasource_type.value
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if not datasource_info:
raise DatasourceNodeError("Datasource info is not set")
datasource_info = datasource_info.value
# get datasource runtime
try:
from core.datasource.datasource_manager import DatasourceManager
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
)
# get parameters
datasource_parameters = datasource_runtime.entity.parameters
parameters = self._generate_parameters(
datasource_parameters=datasource_parameters,
variable_pool=variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
datasource_parameters=datasource_parameters,
variable_pool=variable_pool,
node_data=self.node_data,
for_log=True,
)
try:
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
online_document_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_info.get("workspace_id"),
page_id=datasource_info.get("page").get("page_id"),
type=datasource_info.get("page").get("type"),
),
provider_type=datasource_type,
)
)
yield from self._transform_message(
messages=online_document_result,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
)
case DatasourceProviderType.ONLINE_DRIVE:
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
online_drive_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.online_drive_download_file(
user_id=self.user_id,
request=OnlineDriveDownloadFileRequest(
key=datasource_info.get("key"),
bucket=datasource_info.get("bucket"),
),
provider_type=datasource_type,
)
)
yield from self._transform_message(
messages=online_drive_result,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
**datasource_info,
"datasource_type": datasource_type,
},
)
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], [file_info])
for key, value in datasource_info.items():
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool,
node_id=self.node_id,
variable_key_list=new_key_list,
variable_value=value,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
)
def _generate_parameters(
self,
*,
datasource_parameters: Sequence[DatasourceParameter],
variable_pool: VariablePool,
node_data: DatasourceNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {}
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
datasource_input = node_data.datasource_parameters[parameter_name]
if datasource_input.type == "variable":
variable = variable_pool.get(datasource_input.value)
if variable is None:
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
elif datasource_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _append_variables_recursively(
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
result = {}
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
return result
def _transform_message(
self,
messages: Generator[DatasourceMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE)
else:
transfer_method = FileTransferMethod.DATASOURCE_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"datasource_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"datasource_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.DATASOURCE_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"json": json, "files": files, **variables, "text": text},
metadata={
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
},
inputs=parameters_for_log,
)
)
@classmethod
def version(cls) -> str:
return "1"

View File

@ -0,0 +1,41 @@
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
plugin_id: str
provider_name: str # redundancy
provider_type: str
datasource_name: Optional[str] = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
datasource_parameters: dict[str, DatasourceInput] | None = None

View File

@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
"""Base exception for datasource node errors."""
pass
class DatasourceParameterError(DatasourceNodeError):
"""Exception raised for errors in datasource parameters."""
pass
class DatasourceFileError(DatasourceNodeError):
"""Exception raised for errors related to datasource files."""
pass

View File

@ -7,12 +7,14 @@ class NodeType(StrEnum):
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"

View File

@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode
__all__ = ["KnowledgeIndexNode"]

View File

@ -0,0 +1,159 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str
reranking_model_name: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class FileInfo(BaseModel):
"""
File Info.
"""
file_id: str
class OnlineDocumentIcon(BaseModel):
"""
Document Icon.
"""
icon_url: str
icon_type: str
icon_emoji: str
class OnlineDocumentInfo(BaseModel):
"""
Online document info.
"""
provider: str
workspace_id: str
page_id: str
page_type: str
icon: OnlineDocumentIcon
class WebsiteInfo(BaseModel):
"""
website import info.
"""
provider: str
url: str
class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunks: list[str]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
class KnowledgeIndexNodeData(BaseNodeData):
"""
Knowledge index Node Data.
"""
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]

View File

@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
"""Base class for KnowledgeIndexNode errors."""
class ModelNotExistError(KnowledgeIndexNodeError):
"""Raised when the model does not exist."""
class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
"""Raised when the model credentials are not initialized."""
class ModelNotSupportedError(KnowledgeIndexNodeError):
"""Raised when the model is not supported."""
class ModelQuotaExceededError(KnowledgeIndexNodeError):
"""Raised when the model provider quota is exceeded."""
class InvalidModelTypeError(KnowledgeIndexNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@ -0,0 +1,165 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import func
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from ..base import BaseNode
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
_node_data_cls = KnowledgeIndexNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_INDEX
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
else:
is_preview = False
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
# index knowledge
try:
if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _invoke_knowledge_index(
self,
dataset: Dataset,
node_data: KnowledgeIndexNodeData,
chunks: Mapping[str, Any],
variable_pool: VariablePool,
) -> Any:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
document = db.session.query(Document).filter_by(id=document_id.value).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
index_processor.index(dataset, document, chunks)
indexing_end_at = time.perf_counter()
document.indexing_latency = indexing_end_at - indexing_start_at
# update document status
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
db.session.query(func.sum(DocumentSegment.word_count))
.filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
)
.scalar()
)
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
)
db.session.commit()
return {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"batch": batch.value,
"document_id": document.id,
"document_name": document.name,
"created_at": document.created_at.timestamp(),
"display_status": document.indexing_status,
}
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks)
@classmethod
def version(cls) -> str:
return "1"

View File

@ -57,10 +57,6 @@ class MultipleRetrievalConfig(BaseModel):
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str

View File

@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
@ -124,4 +126,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: AgentNode,
"1": AgentNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
}

View File

@ -61,6 +61,7 @@ def build_from_mapping(
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
}
build_func = build_functions.get(transfer_method)
@ -305,6 +306,53 @@ def _build_from_tool_file(
)
def _build_from_datasource_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
datasource_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.tenant_id == tenant_id,
)
.first()
)
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),
tenant_id=tenant_id,
filename=datasource_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=datasource_file.source_url,
related_id=datasource_file.id,
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
)
def _is_file_valid_with_config(
*,
input_file_type: str,

View File

@ -36,7 +36,10 @@ from core.variables.variables import (
StringVariable,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
class UnsupportedSegmentTypeError(Exception):
@ -75,6 +78,12 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""
This factory function is used to create the environment variable or the conversation variable,

View File

@ -56,6 +56,13 @@ external_knowledge_info_fields = {
doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
icon_info_fields = {
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": fields.String,
}
dataset_detail_fields = {
"id": fields.String,
"name": fields.String,
@ -81,6 +88,13 @@ dataset_detail_fields = {
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
"doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
"built_in_field_enabled": fields.Boolean,
"pipeline_id": fields.String,
"runtime_mode": fields.String,
"chunk_structure": fields.String,
"icon_info": fields.Nested(icon_info_fields),
"is_published": fields.Boolean,
"total_documents": fields.Integer,
"total_available_documents": fields.Integer,
}
dataset_query_detail_fields = {

View File

@ -0,0 +1,164 @@
from flask_restful import fields # type: ignore
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
pipeline_detail_kernel_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
}
related_app_list = {
"data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
"total": fields.Integer,
}
app_detail_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"tracing": fields.Raw,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_partial_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String(attribute="desc_or_prompt"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_fields)),
}
app_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
}
template_fields = {
"name": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"mode": fields.String,
}
template_list_fields = {
"data": fields.List(fields.Nested(template_fields)),
}
site_fields = {
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"default_language": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"app_base_url": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
deleted_tool_fields = {
"type": fields.String,
"tool_name": fields.String,
"provider_id": fields.String,
}
app_detail_fields_with_site = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
app_site_fields = {
"app_id": fields.String,
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"default_language": fields.String,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}
pipeline_import_fields = {
"id": fields.String,
"status": fields.String,
"pipeline_id": fields.String,
"dataset_id": fields.String,
"current_dsl_version": fields.String,
"imported_dsl_version": fields.String,
"error": fields.String,
}
pipeline_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}

View File

@ -40,6 +40,23 @@ conversation_variable_fields = {
"description": fields.String,
}
pipeline_variable_fields = {
"label": fields.String,
"variable": fields.String,
"type": fields.String,
"belong_to_node_id": fields.String,
"max_length": fields.Integer,
"required": fields.Boolean,
"unit": fields.String,
"default_value": fields.Raw,
"options": fields.List(fields.String),
"placeholder": fields.String,
"tooltips": fields.String,
"allowed_file_types": fields.List(fields.String),
"allow_file_extension": fields.List(fields.String),
"allow_file_upload_methods": fields.List(fields.String),
}
workflow_fields = {
"id": fields.String,
"graph": fields.Raw(attribute="graph_dict"),
@ -55,6 +72,7 @@ workflow_fields = {
"tool_published": fields.Boolean,
"environment_variables": fields.List(EnvironmentVariableField()),
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
"rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
}
workflow_partial_fields = {

View File

@ -0,0 +1 @@
{"not_installed": [], "plugin_install_failed": []}

View File

@ -0,0 +1,113 @@
"""add_pipeline_info
Revision ID: b35c3db83d09
Revises: d28f2004b072
Create Date: 2025-05-15 15:58:05.179877
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'b35c3db83d09'
down_revision = '0ab65e1cc7fa'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('pipeline_built_in_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=False),
sa.Column('icon', sa.JSON(), nullable=False),
sa.Column('copyright', sa.String(length=255), nullable=False),
sa.Column('privacy_policy', sa.String(length=255), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('install_count', sa.Integer(), nullable=False),
sa.Column('language', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
)
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=False),
sa.Column('icon', sa.JSON(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('install_count', sa.Integer(), nullable=False),
sa.Column('language', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
op.create_table('pipelines',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
sa.Column('mode', sa.String(length=255), nullable=False),
sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
)
op.create_table('tool_builtin_datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
sa.Column('user_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=256), nullable=False),
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.drop_column('rag_pipeline_variables')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('chunk_structure')
batch_op.drop_column('pipeline_id')
batch_op.drop_column('runtime_mode')
batch_op.drop_column('icon_info')
batch_op.drop_column('keyword_number')
op.drop_table('tool_builtin_datasource_providers')
op.drop_table('pipelines')
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.drop_index('pipeline_customized_template_tenant_idx')
op.drop_table('pipeline_customized_templates')
op.drop_table('pipeline_built_in_templates')
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""add_pipeline_info_2
Revision ID: abb18a379e62
Revises: b35c3db83d09
Create Date: 2025-05-16 16:59:16.423127
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'abb18a379e62'
down_revision = 'b35c3db83d09'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipelines', schema=None) as batch_op:
batch_op.drop_column('mode')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipelines', schema=None) as batch_op:
batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))
# ### end Alembic commands ###

View File

@ -0,0 +1,70 @@
"""add_pipeline_info_3
Revision ID: c459994abfa8
Revises: abb18a379e62
Create Date: 2025-05-30 00:33:14.068312
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'c459994abfa8'
down_revision = 'abb18a379e62'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('datasource_oauth_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
)
op.create_table('datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('auth_type', sa.String(length=255), nullable=False),
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx')
)
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
batch_op.drop_column('pipeline_id')
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
batch_op.drop_column('pipeline_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
batch_op.drop_column('yaml_content')
batch_op.drop_column('chunk_structure')
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
batch_op.drop_column('yaml_content')
batch_op.drop_column('chunk_structure')
op.drop_table('datasource_providers')
op.drop_table('datasource_oauth_params')
# ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff Show More