From 829e6f0d1af3dc55011fbe331efbbf14b7c71336 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 28 Jul 2025 19:29:07 +0800
Subject: [PATCH] add old auth transform

---
 api/commands.py                              | 187 +++++++++++++++++-
 .../console/datasets/data_source.py          | 111 +++++++----
 api/controllers/console/datasets/datasets.py |   3 +
 .../console/datasets/datasets_document.py    |   2 +
 api/controllers/console/datasets/website.py  |   2 +
 api/core/indexing_runner.py                  |   2 +
 .../rag/extractor/entity/extract_setting.py  |   3 +-
 api/core/rag/extractor/extract_processor.py  |   4 +
 .../firecrawl/firecrawl_web_extractor.py     |   9 +-
 .../rag/extractor/jina_reader_extractor.py   |   7 +-
 api/core/rag/extractor/notion_extractor.py   |  37 ++--
 .../rag/extractor/watercrawl/extractor.py    |   9 +-
 api/services/datasource_provider_service.py  |  27 +++
 api/services/website_service.py              |  73 +++++--
 14 files changed, 390 insertions(+), 86 deletions(-)

diff --git a/api/commands.py b/api/commands.py
index 1147db1632..85856ca2a8 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -12,11 +12,14 @@ from werkzeug.exceptions import NotFound
 
 from configs import dify_config
 from constants.languages import languages
-from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID
+from core.helper import encrypter
+from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
+from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.models.document import Document
+from core.tools.entities.tool_entities import CredentialType
 from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
 from events.app_event import app_was_created
 from extensions.ext_database import db
@@ -29,10 +32,11 @@ from models import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
-from models.oauth import DatasourceOauthParamConfig
+from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from models.provider import Provider, ProviderModel
+from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
@@ -1248,3 +1252,187 @@ def setup_datasource_oauth_client(provider, client_params):
     click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
     click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
     click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
+
+
+@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
+def transform_datasource_credentials():
+    """
+    Transform legacy datasource credentials into plugin datasource providers.
+    """
+    try:
+        installer_manager = PluginInstaller()
+        plugin_migration = PluginMigration()
+
+        notion_plugin_id = "langgenius/notion_datasource"
+        firecrawl_plugin_id = "langgenius/firecrawl_datasource"
+        jina_plugin_id = "langgenius/jina_datasource"
+        notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
+        firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
+        jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
+        oauth_credential_type = CredentialType.OAUTH2
+        api_key_credential_type = CredentialType.API_KEY
+
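+        # Credential mapping applied below, one block per provider:
+        #   notion:    DataSourceOauthBinding.access_token -> {"integration_secret": <re-encrypted>, workspace fields}
+        #   firecrawl: DataSourceApiKeyAuthBinding config  -> {"firecrawl_api_key", "base_url"}
+        #   jina:      DataSourceApiKeyAuthBinding config  -> {"integration_secret"}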
+        # migrate notion credentials
+        deal_notion_count = 0
+        notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
+        notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
+        for credential in notion_credentials:
+            tenant_id = credential.tenant_id
+            if tenant_id not in notion_credentials_tenant_mapping:
+                notion_credentials_tenant_mapping[tenant_id] = []
+            notion_credentials_tenant_mapping[tenant_id].append(credential)
+        for tenant_id, credentials in notion_credentials_tenant_mapping.items():
+            # check whether the notion plugin is installed
+            installed_plugins = installer_manager.list_plugins(tenant_id)
+            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+            if notion_plugin_id not in installed_plugins_ids:
+                if notion_plugin_unique_identifier:
+                    # install the notion plugin
+                    installer_manager.install_from_identifiers(
+                        tenant_id,
+                        [notion_plugin_unique_identifier],
+                        PluginInstallationSource.Marketplace,
+                        metas=[
+                            {
+                                "plugin_unique_identifier": notion_plugin_unique_identifier,
+                            }
+                        ],
+                    )
+            auth_count = 0
+            for credential in credentials:
+                auth_count += 1
+                # legacy oauth access token and workspace info
+                access_token = credential.access_token
+                notion_info = credential.source_info
+                workspace_id = notion_info.get("workspace_id")
+                workspace_name = notion_info.get("workspace_name")
+                workspace_icon = notion_info.get("workspace_icon")
+                new_credentials = {
+                    "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
+                    "workspace_id": workspace_id,
+                    "workspace_name": workspace_name,
+                    "workspace_icon": workspace_icon,
+                }
+                datasource_provider = DatasourceProvider(
+                    provider="notion",
+                    tenant_id=tenant_id,
+                    plugin_id=notion_plugin_id,
+                    auth_type=oauth_credential_type.value,
+                    encrypted_credentials=new_credentials,
+                    name=f"Auth {auth_count}",
+                    avatar_url=workspace_icon or "default",
+                    is_default=False,
+                )
+                db.session.add(datasource_provider)
+                deal_notion_count += 1
+        db.session.commit()
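+        # NOTE: assuming the legacy DataSourceApiKeyAuthBinding rows already hold
+        # the api_key encrypted for this tenant, the values below are copied over
+        # as-is (unlike notion, whose plaintext access_token is re-encrypted above).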
+        # migrate firecrawl credentials
+        deal_firecrawl_count = 0
+        firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
+        firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
+        for credential in firecrawl_credentials:
+            tenant_id = credential.tenant_id
+            if tenant_id not in firecrawl_credentials_tenant_mapping:
+                firecrawl_credentials_tenant_mapping[tenant_id] = []
+            firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
+        for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
+            # check whether the firecrawl plugin is installed
+            installed_plugins = installer_manager.list_plugins(tenant_id)
+            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+            if firecrawl_plugin_id not in installed_plugins_ids:
+                if firecrawl_plugin_unique_identifier:
+                    # install the firecrawl plugin
+                    installer_manager.install_from_identifiers(
+                        tenant_id,
+                        [firecrawl_plugin_unique_identifier],
+                        PluginInstallationSource.Marketplace,
+                        metas=[
+                            {
+                                "plugin_unique_identifier": firecrawl_plugin_unique_identifier,
+                            }
+                        ],
+                    )
+            auth_count = 0
+            for credential in credentials:
+                auth_count += 1
+                # legacy api key and base url
+                api_key = credential.credentials.get("config", {}).get("api_key")
+                base_url = credential.credentials.get("config", {}).get("base_url")
+                new_credentials = {
+                    "firecrawl_api_key": api_key,
+                    "base_url": base_url,
+                }
+                datasource_provider = DatasourceProvider(
+                    provider="firecrawl",
+                    tenant_id=tenant_id,
+                    plugin_id=firecrawl_plugin_id,
+                    auth_type=api_key_credential_type.value,
+                    encrypted_credentials=new_credentials,
+                    name=f"Auth {auth_count}",
+                    avatar_url="default",
+                    is_default=False,
+                )
+                db.session.add(datasource_provider)
+                deal_firecrawl_count += 1
+        db.session.commit()
+        # migrate jina credentials
+        deal_jina_count = 0
+        jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jina").all()
+        jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
+        for credential in jina_credentials:
+            tenant_id = credential.tenant_id
+            if tenant_id not in jina_credentials_tenant_mapping:
+                jina_credentials_tenant_mapping[tenant_id] = []
+            jina_credentials_tenant_mapping[tenant_id].append(credential)
+        for tenant_id, credentials in jina_credentials_tenant_mapping.items():
+            # check whether the jina plugin is installed
+            installed_plugins = installer_manager.list_plugins(tenant_id)
+            installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+            if jina_plugin_id not in installed_plugins_ids:
+                if jina_plugin_unique_identifier:
+                    # install the jina plugin
+                    installer_manager.install_from_identifiers(
+                        tenant_id,
+                        [jina_plugin_unique_identifier],
+                        PluginInstallationSource.Marketplace,
+                        metas=[
+                            {
+                                "plugin_unique_identifier": jina_plugin_unique_identifier,
+                            }
+                        ],
+                    )
+            auth_count = 0
+            for credential in credentials:
+                auth_count += 1
+                # legacy api key
+                api_key = credential.credentials.get("config", {}).get("api_key")
+                new_credentials = {
+                    "integration_secret": api_key,
+                }
+                datasource_provider = DatasourceProvider(
+                    provider="jina",
+                    tenant_id=tenant_id,
+                    plugin_id=jina_plugin_id,
+                    auth_type=api_key_credential_type.value,
+                    encrypted_credentials=new_credentials,
+                    name=f"Auth {auth_count}",
+                    avatar_url="default",
+                    is_default=False,
+                )
+                db.session.add(datasource_provider)
+                deal_jina_count += 1
+        db.session.commit()
+    except Exception as e:
+        click.echo(click.style(f"Error transforming datasource credentials: {str(e)}", fg="red"))
+        return
+    click.echo(click.style(f"Transformed notion credentials. deal_notion_count: {deal_notion_count}", fg="green"))
+    click.echo(click.style(f"Transformed firecrawl credentials. deal_firecrawl_count: {deal_firecrawl_count}", fg="green"))
+    click.echo(click.style(f"Transformed jina credentials. deal_jina_count: {deal_jina_count}", fg="green"))
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 39f8ab5787..ecf5d7d336 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,4 +1,6 @@
 import json
+from collections.abc import Generator
+from typing import cast
 
 from flask import request
 from flask_login import current_user
@@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound
 
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.notion_extractor import NotionExtractor
@@ -17,7 +21,8 @@ from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
 
@@ -112,6 +117,18 @@ class DataSourceNotionListApi(Resource):
     @marshal_with(integrate_notion_info_list_fields)
     def get(self):
         dataset_id = request.args.get("dataset_id", default=None, type=str)
+        credential_id = request.args.get("credential_id", default=None, type=str)
+        if not credential_id:
+            raise ValueError("Credential id is required.")
+        datasource_provider_service = DatasourceProviderService()
+        credential = datasource_provider_service.get_real_credential_by_id(
+            tenant_id=current_user.current_tenant_id,
+            credential_id=credential_id,
+            provider="notion",
+            plugin_id="langgenius/notion_datasource",
+        )
+        if not credential:
+            raise NotFound("Credential not found.")
         exist_page_ids = []
         with Session(db.engine) as session:
             # import notion in the exist dataset
@@ -135,31 +152,46 @@ class DataSourceNotionListApi(Resource):
                     data_source_info = json.loads(document.data_source_info)
                     exist_page_ids.append(data_source_info["notion_page_id"])
             # get all authorized pages
-            data_source_bindings = session.scalars(
-                select(DataSourceOauthBinding).filter_by(
-                    tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-                )
-            ).all()
-            if not data_source_bindings:
-                return {"notion_info": []}, 200
-            pre_import_info_list = []
-            for data_source_binding in data_source_bindings:
-                source_info = data_source_binding.source_info
-                pages = source_info["pages"]
-                # Filter out already bound pages
-                for page in pages:
-                    if page["page_id"] in exist_page_ids:
-                        page["is_bound"] = True
-                    else:
-                        page["is_bound"] = False
-                pre_import_info = {
-                    "workspace_name": source_info["workspace_name"],
-                    "workspace_icon": source_info["workspace_icon"],
-                    "workspace_id": source_info["workspace_id"],
-                    "pages": pages,
-                }
-                pre_import_info_list.append(pre_import_info)
-            return {"notion_info": pre_import_info_list}, 200
+        from core.datasource.datasource_manager import DatasourceManager
+
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id="langgenius/notion_datasource/notion",
+            datasource_name="notion",
+            tenant_id=current_user.current_tenant_id,
+            datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
+        )
+        datasource_runtime.runtime.credentials = credential
+        datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+        online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+            datasource_runtime.get_online_document_pages(
+                user_id=current_user.id,
+                datasource_parameters={},
+                provider_type=datasource_runtime.datasource_provider_type(),
+            )
+        )
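+        # Pages from every workspace yielded by the generator are flattened into
+        # a single list; workspace_info keeps the last workspace seen.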
+        pages = []
+        workspace_info = {}
+        for message in online_document_result:
+            result = message.result
+            for info in result:
+                workspace_info = {
+                    "workspace_id": info.workspace_id,
+                    "workspace_name": info.workspace_name,
+                    "workspace_icon": info.workspace_icon,
+                }
+                for page in info.pages:
+                    page_info = {
+                        "page_id": page.page_id,
+                        "page_name": page.page_name,
+                        "type": page.type,
+                        "parent_id": page.parent_id,
+                        "is_bound": page.page_id in exist_page_ids,
+                        "page_icon": page.page_icon,
+                    }
+                    pages.append(page_info)
+        return {"notion_info": {**workspace_info, "pages": pages}}, 200
 
 
 class DataSourceNotionApi(Resource):
@@ -167,27 +199,26 @@ class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, workspace_id, page_id, page_type):
+        credential_id = request.args.get("credential_id", default=None, type=str)
+        if not credential_id:
+            raise ValueError("Credential id is required.")
+        datasource_provider_service = DatasourceProviderService()
+        credential = datasource_provider_service.get_real_credential_by_id(
+            tenant_id=current_user.current_tenant_id,
+            credential_id=credential_id,
+            provider="notion",
+            plugin_id="langgenius/notion_datasource",
+        )
+        if not credential:
+            raise NotFound("Credential not found.")
+
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        with Session(db.engine) as session:
-            data_source_binding = session.execute(
-                select(DataSourceOauthBinding).where(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-                    )
-                )
-            ).scalar_one_or_none()
-            if not data_source_binding:
-                raise NotFound("Data source binding not found.")
 
         extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
             notion_page_type=page_type,
-            notion_access_token=data_source_binding.access_token,
+            notion_access_token=credential.get("integration_secret"),
             tenant_id=current_user.current_tenant_id,
         )
@@ -212,10 +243,12 @@ class DataSourceNotionApi(Resource):
         extract_settings = []
         for notion_info in notion_info_list:
             workspace_id = notion_info["workspace_id"]
+            credential_id = notion_info.get("credential_id")
             for page in notion_info["pages"]:
                 extract_setting = ExtractSetting(
                     datasource_type="notion_import",
                     notion_info={
+                        "credential_id": credential_id,
                         "notion_workspace_id": workspace_id,
                         "notion_obj_id": page["page_id"],
                         "notion_page_type": page["type"],
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 14db6706f6..bf4a3bac5d 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -438,10 +438,12 @@ class DatasetIndexingEstimateApi(Resource):
             notion_info_list = args["info_list"]["notion_info_list"]
             for notion_info in notion_info_list:
                 workspace_id = notion_info["workspace_id"]
+                credential_id = notion_info.get("credential_id")
                 for page in notion_info["pages"]:
                     extract_setting = ExtractSetting(
                         datasource_type="notion_import",
                         notion_info={
+                            "credential_id": credential_id,
                             "notion_workspace_id": workspace_id,
                             "notion_obj_id": page["page_id"],
                             "notion_page_type": page["type"],
@@ -462,6 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
                             "tenant_id": current_user.current_tenant_id,
                             "mode": "crawl",
                             "only_main_content": website_info_list["only_main_content"],
+                            "credential_id": website_info_list["credential_id"],
                         },
                         document_model=args["doc_form"],
                     )
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index de38e58b11..8d880a9912 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -510,6 +510,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 extract_setting = ExtractSetting(
                     datasource_type="notion_import",
                     notion_info={
+                        "credential_id": data_source_info.get("credential_id"),
                         "notion_workspace_id": data_source_info["notion_workspace_id"],
                         "notion_obj_id": data_source_info["notion_page_id"],
                         "notion_page_type": data_source_info["type"],
@@ -528,6 +529,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                         "tenant_id": current_user.current_tenant_id,
                         "mode": data_source_info["mode"],
                         "only_main_content": data_source_info["only_main_content"],
+                        "credential_id": data_source_info.get("credential_id"),
                     },
                     document_model=document.doc_form,
                 )
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index fcdc91ec67..f8b1908f68 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -23,6 +23,7 @@ class WebsiteCrawlApi(Resource):
         )
         parser.add_argument("url", type=str, required=True, nullable=True, location="json")
         parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
+        parser.add_argument("credential_id", type=str, required=True, nullable=True, location="json")
         args = parser.parse_args()
 
         # Create typed request and validate
@@ -48,6 +49,7 @@ class WebsiteCrawlStatusApi(Resource):
         parser.add_argument(
             "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
         )
+        parser.add_argument("credential_id", type=str, required=True, nullable=True, location="args")
         args = parser.parse_args()
 
         # Create typed request and validate
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 2387658bb6..538c3e20af 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -365,6 +365,7 @@ class IndexingRunner:
                 extract_setting = ExtractSetting(
                     datasource_type="notion_import",
                     notion_info={
+                        "credential_id": data_source_info.get("credential_id"),
                         "notion_workspace_id": data_source_info["notion_workspace_id"],
                         "notion_obj_id": data_source_info["notion_page_id"],
                         "notion_page_type": data_source_info["type"],
@@ -391,6 +392,7 @@ class IndexingRunner:
                         "url": data_source_info["url"],
                         "mode": data_source_info["mode"],
                         "only_main_content": data_source_info["only_main_content"],
+                        "credential_id": data_source_info.get("credential_id"),
                     },
                     document_model=dataset_document.doc_form,
                 )
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index 1593ad1475..3e38c9153c 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -10,7 +10,9 @@ class NotionInfo(BaseModel):
     """
     Notion import info.
     """
-
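+    # Optional for backward compatibility: documents created before the
+    # plugin-credential migration carry no credential_id in data_source_info.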
+    credential_id: Optional[str] = None
     notion_workspace_id: str
     notion_obj_id: str
     notion_page_type: str
@@ -35,6 +37,7 @@ class WebsiteInfo(BaseModel):
     mode: str
     tenant_id: str
     only_main_content: bool = False
+    credential_id: Optional[str] = None
 
 
 class ExtractSetting(BaseModel):
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index bc19899ea5..44b4779493 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -171,6 +171,7 @@ class ExtractProcessor:
                     notion_page_type=extract_setting.notion_info.notion_page_type,
                     document_model=extract_setting.notion_info.document,
                     tenant_id=extract_setting.notion_info.tenant_id,
+                    credential_id=extract_setting.notion_info.credential_id,
                 )
                 return extractor.extract()
             elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
@@ -182,6 +183,7 @@ class ExtractProcessor:
                     tenant_id=extract_setting.website_info.tenant_id,
                     mode=extract_setting.website_info.mode,
                     only_main_content=extract_setting.website_info.only_main_content,
+                    credential_id=extract_setting.website_info.credential_id,
                 )
                 return extractor.extract()
             elif extract_setting.website_info.provider == "watercrawl":
@@ -191,6 +193,7 @@ class ExtractProcessor:
                     tenant_id=extract_setting.website_info.tenant_id,
                     mode=extract_setting.website_info.mode,
                     only_main_content=extract_setting.website_info.only_main_content,
+                    credential_id=extract_setting.website_info.credential_id,
                 )
                 return extractor.extract()
             elif extract_setting.website_info.provider == "jinareader":
@@ -200,6 +203,7 @@ class ExtractProcessor:
                     tenant_id=extract_setting.website_info.tenant_id,
                     mode=extract_setting.website_info.mode,
                     only_main_content=extract_setting.website_info.only_main_content,
+                    credential_id=extract_setting.website_info.credential_id,
                 )
                 return extractor.extract()
             else:
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
index 4de8318881..cf5ede4daa 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from services.website_service import WebsiteService
@@ -15,19 +17,23 @@ class FirecrawlWebExtractor(BaseExtractor):
         only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
     """
 
""" - def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True, + credential_id: Optional[str] = None): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id self.tenant_id = tenant_id self.mode = mode self.only_main_content = only_main_content + self.credential_id = credential_id def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] if self.mode == "crawl": - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id, self.credential_id) if crawl_data is None: return [] document = Document( @@ -41,7 +44,7 @@ class FirecrawlWebExtractor(BaseExtractor): documents.append(document) elif self.mode == "scrape": scrape_data = WebsiteService.get_scrape_url_data( - "firecrawl", self._url, self.tenant_id, self.only_main_content + "firecrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id ) document = Document( diff --git a/api/core/rag/extractor/jina_reader_extractor.py b/api/core/rag/extractor/jina_reader_extractor.py index 5b780af126..a74fb203e2 100644 --- a/api/core/rag/extractor/jina_reader_extractor.py +++ b/api/core/rag/extractor/jina_reader_extractor.py @@ -1,3 +1,4 @@ +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from services.website_service import WebsiteService @@ -8,19 +9,21 @@ class JinaReaderWebExtractor(BaseExtractor): Crawl and scrape websites and return content in clean llm-ready markdown. 
""" - def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False, + credential_id: Optional[str] = None): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id self.tenant_id = tenant_id self.mode = mode self.only_main_content = only_main_content + self.credential_id = credential_id def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] if self.mode == "crawl": - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id) + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id, self.credential_id) if crawl_data is None: return [] document = Document( diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 875626eb34..73bf7c81fb 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -9,7 +9,9 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Document as DocumentModel +from models.oauth import DatasourceProvider from models.source import DataSourceOauthBinding +from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -36,16 +38,18 @@ class NotionExtractor(BaseExtractor): tenant_id: str, document_model: Optional[DocumentModel] = None, notion_access_token: Optional[str] = None, + credential_id: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model self._notion_workspace_id = notion_workspace_id self._notion_obj_id = notion_obj_id self._notion_page_type = notion_page_type + self._credential_id = credential_id if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._credential_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: @@ -363,23 +367,18 @@ class NotionExtractor(BaseExtractor): return cast(str, data["last_edited_time"]) @classmethod - def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .where( - db.and_( - DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', - ) - ) - .first() + def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str: + # get credential from tenant_id and credential_id + if not credential_id: + raise Exception(f"No credential id found for tenant {tenant_id}") + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_real_credential_by_id( + tenant_id=tenant_id, + credential_id=credential_id, + provider="notion", + plugin_id="langgenius/notion_datasource", ) + if not credential: + raise Exception(f"No notion credential found for tenant {tenant_id} and credential {credential_id}") - if not data_source_binding: - raise Exception( - f"No notion data 
-                f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
-            )
-
-        return cast(str, data_source_binding.access_token)
+        return cast(str, credential["integration_secret"])
diff --git a/api/core/rag/extractor/watercrawl/extractor.py b/api/core/rag/extractor/watercrawl/extractor.py
index 40d1740962..5559917cc5 100644
--- a/api/core/rag/extractor/watercrawl/extractor.py
+++ b/api/core/rag/extractor/watercrawl/extractor.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from services.website_service import WebsiteService
@@ -16,19 +18,23 @@ class WaterCrawlWebExtractor(BaseExtractor):
         only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
     """
 
-    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
+    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
+                 credential_id: Optional[str] = None):
         """Initialize with url, api_key, base_url and mode."""
         self._url = url
         self.job_id = job_id
         self.tenant_id = tenant_id
         self.mode = mode
         self.only_main_content = only_main_content
+        self.credential_id = credential_id
 
     def extract(self) -> list[Document]:
         """Extract content from the URL."""
         documents = []
         if self.mode == "crawl":
-            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
+            crawl_data = WebsiteService.get_crawl_url_data(
+                self.job_id, "watercrawl", self._url, self.tenant_id, self.credential_id
+            )
             if crawl_data is None:
                 return []
             document = Document(
@@ -42,7 +48,7 @@ class WaterCrawlWebExtractor(BaseExtractor):
             documents.append(document)
         elif self.mode == "scrape":
             scrape_data = WebsiteService.get_scrape_url_data(
-                "watercrawl", self._url, self.tenant_id, self.only_main_content
+                "watercrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id
             )
 
             document = Document(
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 4b1103547d..3f85310ab2 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -56,6 +56,33 @@ class DatasourceProviderService:
             return {}
         return datasource_provider.encrypted_credentials
 
+    def get_real_credential_by_id(
+        self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
+    ) -> dict[str, Any]:
+        """
+        Get a credential by id, with secret fields decrypted.
+        """
+        with Session(db.engine) as session:
+            datasource_provider = (
+                session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
+            )
+            if not datasource_provider:
+                return {}
+            encrypted_credentials = datasource_provider.encrypted_credentials
+            # Get provider credential secret variables
+            credential_secret_variables = self.extract_secret_variables(
+                tenant_id=tenant_id,
+                provider_id=f"{plugin_id}/{provider}",
+                credential_type=CredentialType.of(datasource_provider.auth_type),
+            )
+
+            # Decrypt secret fields in a copy of the stored credentials
+            copy_credentials = encrypted_credentials.copy()
+            for key, value in copy_credentials.items():
+                if key in credential_secret_variables:
+                    copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
+            return copy_credentials
+
     def update_datasource_provider_name(
         self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
     ):
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 991b669737..a0bee311e7 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -12,6 +12,7 @@ from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from services.auth.api_key_auth_service import ApiKeyAuthService
+from services.datasource_provider_service import DatasourceProviderService
 
 
 @dataclass
@@ -61,6 +62,7 @@ class WebsiteCrawlApiRequest:
     provider: str
     url: str
     options: dict[str, Any]
+    credential_id: Optional[str] = None
 
     def to_crawl_request(self) -> CrawlRequest:
         """Convert API request to internal CrawlRequest."""
@@ -98,28 +100,58 @@ class WebsiteCrawlStatusApiRequest:
     provider: str
     job_id: str
+    credential_id: Optional[str] = None
 
     @classmethod
     def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
         """Create from Flask-RESTful parsed arguments."""
         provider = args.get("provider")
+        credential_id = args.get("credential_id")
         if not provider:
             raise ValueError("Provider is required")
         if not job_id:
             raise ValueError("Job ID is required")
 
-        return cls(provider=provider, job_id=job_id)
+        return cls(provider=provider, job_id=job_id, credential_id=credential_id)
 
 
 class WebsiteService:
     """Service class for website crawling operations using different providers."""
 
     @classmethod
-    def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
+    def _get_credentials_and_config(
+        cls, tenant_id: str, provider: str, credential_id: Optional[str] = None
+    ) -> tuple[Any, Any]:
         """Get and validate credentials for a provider."""
-        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
-        if not credentials or "config" not in credentials:
-            raise ValueError("No valid credentials found for the provider")
-        return credentials, credentials["config"]
+        if credential_id:
+            if provider == "firecrawl":
+                plugin_id = "langgenius/firecrawl_datasource"
+            elif provider == "watercrawl":
+                plugin_id = "langgenius/watercrawl_datasource"
+            elif provider == "jinareader":
+                plugin_id = "langgenius/jinareader_datasource"
+            else:
+                raise ValueError(f"Unsupported provider: {provider}")
+            datasource_provider_service = DatasourceProviderService()
+            credential = datasource_provider_service.get_real_credential_by_id(
+                tenant_id=tenant_id,
+                credential_id=credential_id,
+                provider=provider,
+                plugin_id=plugin_id,
+            )
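+            # NOTE: assumes the secret field name matches what the credential
+            # migration writes ("firecrawl_api_key" for firecrawl, "integration_secret"
+            # for jina); fall through the known names in order.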
+            api_key = (
+                credential.get("api_key")
+                or credential.get("firecrawl_api_key")
+                or credential.get("integration_secret")
+            )
+            return api_key, credential
+        else:
+            credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
+            if not credentials or "config" not in credentials:
+                raise ValueError("No valid credentials found for the provider")
+            return credentials, credentials["config"]
 
     @classmethod
     def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
@@ -144,8 +176,13 @@ class WebsiteService:
         """Crawl a URL using the specified provider with typed request."""
         request = api_request.to_crawl_request()
 
-        _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
-        api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+        credential_api_key, config = cls._get_credentials_and_config(
+            current_user.current_tenant_id, request.provider, api_request.credential_id
+        )
+        if api_request.credential_id:
+            api_key = credential_api_key
+        else:
+            api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
 
         if request.provider == "firecrawl":
             return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
@@ -227,16 +264,21 @@ class WebsiteService:
         return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
 
     @classmethod
-    def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
+    def get_crawl_status(cls, job_id: str, provider: str, credential_id: Optional[str] = None) -> dict[str, Any]:
         """Get crawl status using string parameters."""
-        api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
+        api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id, credential_id=credential_id)
         return cls.get_crawl_status_typed(api_request)
 
     @classmethod
     def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
         """Get crawl status using typed request."""
-        _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
-        api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+        credential_api_key, config = cls._get_credentials_and_config(
+            current_user.current_tenant_id, api_request.provider, api_request.credential_id
+        )
+        if api_request.credential_id:
+            api_key = credential_api_key
+        else:
+            api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
 
         if api_request.provider == "firecrawl":
             return cls._get_firecrawl_status(api_request.job_id, api_key, config)
{"status": "active", "job_id": response.json().get("data", {}).get("taskId")} @classmethod - def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]: + def get_crawl_status(cls, job_id: str, provider: str, credential_id: Optional[str] = None) -> dict[str, Any]: """Get crawl status using string parameters.""" - api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id) + api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id, credential_id=credential_id) return cls.get_crawl_status_typed(api_request) @classmethod def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]: """Get crawl status using typed request.""" - _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider) - api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, + api_request.provider, + api_request.credential_id) + if api_request.credential_id: + api_key = _ + else: + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) if api_request.provider == "firecrawl": return cls._get_firecrawl_status(api_request.job_id, api_key, config) @@ -309,9 +337,12 @@ class WebsiteService: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None: - _, config = cls._get_credentials_and_config(tenant_id, provider) - api_key = cls._get_decrypted_api_key(tenant_id, config) + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str, credential_id: Optional[str] = None) -> dict[str, Any] | None: + _, config = cls._get_credentials_and_config(tenant_id, provider, credential_id) + if credential_id: + api_key = _ + else: + api_key = cls._get_decrypted_api_key(tenant_id, config) if provider == "firecrawl": return cls._get_firecrawl_url_data(job_id, url, api_key, config) @@ -381,11 +412,17 @@ class WebsiteService: return None @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]: + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool, + credential_id: Optional[str] = None) -> dict[str, Any]: request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content) - _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider) - api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config) + _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, + provider=request.provider, + credential_id=credential_id) + if credential_id: + api_key = _ + else: + api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config) if request.provider == "firecrawl": return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)