# dify/api/controllers/console/datasets/data_source.py
#
# 424 lines
# 17 KiB
# Python

import json
from collections.abc import Generator
from datetime import datetime
from typing import Any, Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_serializer
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.fields import SimpleResultResponse, TextContentResponse
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.entities.knowledge_entities import IndexingEstimate
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
# Request body validated by the POST /datasets/notion-indexing-estimate handler.
class NotionEstimatePayload(BaseModel):
    # One entry per workspace. The estimate handler reads "workspace_id",
    # an optional "credential_id" and "pages" (each with "page_id" and "type").
    notion_info_list: list[dict[str, Any]]
    # Forwarded unchanged to IndexingRunner.indexing_estimate.
    process_rule: dict[str, Any]
    # Document form; defaults to "text_model" when the client omits it.
    doc_form: str = Field(default="text_model")
    # Document language; defaults to "English" when the client omits it.
    doc_language: str = Field(default="English")
# Query-string parameters accepted by GET /notion/pre-import/pages.
class DataSourceNotionListQuery(BaseModel):
    # Optional. When given, pages already imported into this dataset are
    # reported with is_bound=True by the list handler.
    dataset_id: str | None = Field(default=None, description="Dataset ID")
    # Required and non-empty (min_length=1).
    credential_id: str = Field(..., description="Credential ID", min_length=1)
# Query-string parameters accepted by the Notion page preview endpoint.
class DataSourceNotionPreviewQuery(BaseModel):
    # Required and non-empty (min_length=1).
    credential_id: str = Field(..., description="Credential ID", min_length=1)
# Icon descriptor used as the `page_icon` field of the page response models.
# Every field is optional and defaults to None.
class DataSourceIntegrateIconResponse(ResponseModel):
    type: str | None = None
    url: str | None = None
    emoji: str | None = None
# One page entry inside DataSourceIntegrateWorkspaceResponse.pages.
class DataSourceIntegratePageResponse(ResponseModel):
    page_name: str
    page_id: str
    # Required field, but its value may be None.
    page_icon: DataSourceIntegrateIconResponse | None
    parent_id: str
    type: str
# Shape of DataSourceIntegrateResponse.source_info: workspace details plus its pages.
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
    workspace_name: str | None
    workspace_id: str | None
    workspace_icon: str | None
    pages: list[DataSourceIntegratePageResponse]
    total: int
# One entry of the GET /data-source/integrates response list.
class DataSourceIntegrateResponse(ResponseModel):
    # None in the placeholder entry the list handler builds for an unbound provider.
    id: str | None
    provider: str
    # Accepts a datetime, an int or None; see serialize_created_at below.
    created_at: datetime | int | None
    is_bound: bool
    disabled: bool | None
    link: str
    source_info: DataSourceIntegrateWorkspaceResponse | None

    # Serialize `created_at` through libs.helper.to_timestamp
    # (presumably a Unix timestamp -- confirm against that helper).
    @field_serializer("created_at")
    def serialize_created_at(self, value: datetime | int | None) -> int | None:
        return to_timestamp(value)
# Envelope returned by GET /data-source/integrates.
class DataSourceIntegrateListResponse(ResponseModel):
    data: list[DataSourceIntegrateResponse]
# One page entry inside NotionIntegrateWorkspaceResponse.pages.
class NotionIntegratePageResponse(ResponseModel):
    page_name: str
    page_id: str
    # Required field, but its value may be None.
    page_icon: DataSourceIntegrateIconResponse | None
    # Unlike DataSourceIntegratePageResponse.parent_id, this may be None.
    parent_id: str | None
    type: str
    # Set by the list handler to whether page_id is among the pages already
    # imported into the requested dataset.
    is_bound: bool
# One workspace entry of NotionIntegrateInfoListResponse.notion_info.
class NotionIntegrateWorkspaceResponse(ResponseModel):
    workspace_name: str | None
    workspace_id: str | None
    workspace_icon: str | None
    pages: list[NotionIntegratePageResponse]
# Envelope returned by GET /notion/pre-import/pages.
class NotionIntegrateInfoListResponse(ResponseModel):
    notion_info: list[NotionIntegrateWorkspaceResponse]
# Register the request and response models on the console namespace. The
# resource decorators below look them up via console_ns.models[<Model>.__name__],
# so these calls must run before the resource classes are defined.
register_schema_models(console_ns, NotionEstimatePayload)
register_response_schema_models(
    console_ns,
    DataSourceIntegrateListResponse,
    IndexingEstimate,
    NotionIntegrateInfoListResponse,
    SimpleResultResponse,
    TextContentResponse,
)
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
)
class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
def get(self) -> tuple[dict[str, Any], int]:
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
).all()
base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source"
providers = ["notion"]
integrate_data = []
for provider in providers:
# existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None)
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
if existing_integrates:
for existing_integrate in list(existing_integrates):
integrate_data.append(
{
"id": existing_integrate.id,
"provider": provider,
"created_at": existing_integrate.created_at,
"is_bound": True,
"disabled": existing_integrate.disabled,
"source_info": existing_integrate.source_info,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
else:
integrate_data.append(
{
"id": None,
"provider": provider,
"created_at": None,
"source_info": None,
"is_bound": False,
"disabled": None,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, binding_id: UUID, action: Literal["enable", "disable"]) -> tuple[dict[str, str], int]:
_, current_tenant_id = current_account_with_tenant()
binding_id_str = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200
@console_ns.route("/notion/pre-import/pages")
class DataSourceNotionListApi(Resource):
    """List the Notion pages a datasource credential can access before import."""

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
    @console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
    def get(self) -> tuple[dict[str, Any], int]:
        """Return the authorized Notion pages, flagging those already imported.

        Query parameters are validated with ``DataSourceNotionListQuery``. When
        ``dataset_id`` is given, pages that already back an enabled
        ``notion_import`` document of that dataset get ``is_bound=True``.

        Returns:
            The dumped ``NotionIntegrateInfoListResponse`` and HTTP status 200.

        Raises:
            NotFound: The credential or the dataset does not exist.
            ValueError: The dataset is not a Notion import dataset.
        """
        current_user, current_tenant_id = current_account_with_tenant()
        query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_tenant_id,
            credential_id=query.credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )
        if not credential:
            raise NotFound("Credential not found.")

        # A set gives constant-time membership checks when flagging pages below.
        exist_page_ids = set()
        # import notion in the exist dataset
        if query.dataset_id:
            dataset = DatasetService.get_dataset(query.dataset_id)
            if not dataset:
                raise NotFound("Dataset not found.")
            if dataset.data_source_type != "notion_import":
                raise ValueError("Dataset is not notion type.")
            # Keep the transaction limited to this read so it is not held open
            # during the remote plugin call further down. The attributes are read
            # inside the block because the default expire_on_commit=True expires
            # the loaded objects once the session commits.
            with sessionmaker(db.engine).begin() as session:
                documents = session.scalars(
                    select(Document).where(
                        Document.dataset_id == query.dataset_id,
                        Document.tenant_id == current_tenant_id,
                        Document.data_source_type == "notion_import",
                        Document.enabled.is_(True),
                    )
                ).all()
                for document in documents:
                    data_source_info = json.loads(document.data_source_info)
                    exist_page_ids.add(data_source_info["notion_page_id"])

        # get all authorized pages
        # NOTE(review): imported locally, presumably to avoid a circular import -- confirm.
        from core.datasource.datasource_manager import DatasourceManager

        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id="langgenius/notion_datasource/notion_datasource",
            datasource_name="notion_datasource",
            tenant_id=current_tenant_id,
            datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
        )
        # The credential is guaranteed to be set here (see the NotFound guard above).
        datasource_runtime.runtime.credentials = credential
        datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
        online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
            datasource_runtime.get_online_document_pages(
                user_id=current_user.id,
                datasource_parameters={},
                provider_type=datasource_runtime.datasource_provider_type(),
            )
        )

        pages: list[dict[str, Any]] = []
        workspace_info: dict[str, Any] = {}
        for message in online_document_result:
            for info in message.result:
                # Each workspace overwrites the previous one while the pages of
                # all workspaces are merged into a single list.
                # NOTE(review): assumes one workspace per credential -- confirm.
                workspace_info = {
                    "workspace_id": info.workspace_id,
                    "workspace_name": info.workspace_name,
                    "workspace_icon": info.workspace_icon,
                }
                for page in info.pages:
                    pages.append(
                        {
                            "page_id": page.page_id,
                            "page_name": page.page_name,
                            "type": page.type,
                            "parent_id": page.parent_id,
                            "is_bound": page.page_id in exist_page_ids,
                            "page_icon": page.page_icon,
                        }
                    )

        # No workspace reported by the plugin means an empty result list.
        notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
        return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
class DataSourceNotionPreviewApi(Resource):
    """Preview one authorized Notion page through the datasource credential."""

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
    @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
    def get(self, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
        """Extract a Notion page and return its text content.

        Args:
            page_id: ID of the Notion object, taken from the URL.
            page_type: Notion page type, taken from the URL and passed to the
                extractor unchanged.

        Returns:
            ``{"content": <text>}`` with the extracted documents joined by
            newlines, and HTTP status 200.

        Raises:
            NotFound: The credential does not exist for the current tenant.
        """
        _, current_tenant_id = current_account_with_tenant()
        query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_tenant_id,
            credential_id=query.credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )
        # Fail with a 404 (as the page list endpoint does) instead of crashing
        # on `credential.get` when the lookup returns nothing.
        if not credential:
            raise NotFound("Credential not found.")
        page_id_str = str(page_id)
        extractor = NotionExtractor(
            notion_workspace_id="",
            notion_obj_id=page_id_str,
            notion_page_type=page_type,
            notion_access_token=credential.get("integration_secret"),
            tenant_id=current_tenant_id,
        )
        text_docs = extractor.extract()
        return {"content": "\n".join(doc.page_content for doc in text_docs)}, 200
@console_ns.route("/datasets/notion-indexing-estimate")
class DataSourceNotionIndexingEstimateApi(Resource):
    """Estimate indexing work for selected Notion pages."""

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
    @console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
    def post(self) -> tuple[dict[str, Any], int]:
        _, tenant_id = current_account_with_tenant()
        payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
        args = payload.model_dump()
        # Validate the dumped arguments before any extract setting is built.
        DocumentService.estimate_args_validate(args)

        # One ExtractSetting per selected page; the workspace-level values are
        # read once per workspace entry, before its pages are visited.
        extract_settings = []
        for workspace in payload.notion_info_list:
            workspace_id = workspace["workspace_id"]
            credential_id = workspace.get("credential_id")
            extract_settings.extend(
                ExtractSetting(
                    datasource_type=DatasourceType.NOTION,
                    notion_info=NotionInfo.model_validate(
                        {
                            "credential_id": credential_id,
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page["page_id"],
                            "notion_page_type": page["type"],
                            "tenant_id": tenant_id,
                        }
                    ),
                    document_model=args["doc_form"],
                )
                for page in workspace["pages"]
            )

        estimate = IndexingRunner().indexing_estimate(
            tenant_id,
            extract_settings,
            args["process_rule"],
            args["doc_form"],
            args["doc_language"],
        )
        return dump_response(IndexingEstimate, estimate), 200
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
class DataSourceNotionDatasetSyncApi(Resource):
    """Queue an indexing sync task for every document of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
    def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
        dataset_key = str(dataset_id)
        # Refuse to queue anything for a dataset that does not exist.
        if DatasetService.get_dataset(dataset_key) is None:
            raise NotFound("Dataset not found.")
        # One sync task per document of the dataset.
        for doc in DocumentService.get_document_by_dataset_id(dataset_key):
            document_indexing_sync_task.delay(dataset_key, doc.id)
        return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
class DataSourceNotionDocumentSyncApi(Resource):
    """Queue an indexing sync task for a single document of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
    def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
        dataset_key, document_key = str(dataset_id), str(document_id)
        # Both the dataset and the document must exist before the task is queued.
        if DatasetService.get_dataset(dataset_key) is None:
            raise NotFound("Dataset not found.")
        if DocumentService.get_document(dataset_key, document_key) is None:
            raise NotFound("Document not found.")
        document_indexing_sync_task.delay(dataset_key, document_key)
        return {"result": "success"}, 200