add old auth transform

This commit is contained in:
jyong 2025-07-28 19:29:07 +08:00
parent b0cd4daf54
commit 829e6f0d1a
14 changed files with 390 additions and 86 deletions

View File

@ -12,11 +12,15 @@ from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID
from core.helper import encrypter
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@ -29,10 +33,12 @@ from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.auth import firecrawl
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
@ -1248,3 +1254,180 @@ def setup_datasource_oauth_client(provider, client_params):
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
def _group_bindings_by_tenant(bindings):
    """Group legacy credential bindings by ``tenant_id``, keeping query order within each tenant."""
    grouped: dict[str, list] = {}
    for binding in bindings:
        grouped.setdefault(binding.tenant_id, []).append(binding)
    return grouped


def _ensure_datasource_plugin_installed(installer_manager, tenant_id, plugin_id, plugin_unique_identifier):
    """Install ``plugin_id`` for the tenant from the marketplace unless it is already installed.

    Nothing is installed when no unique identifier could be resolved for the plugin.
    """
    installed_plugin_ids = {plugin.plugin_id for plugin in installer_manager.list_plugins(tenant_id)}
    if plugin_id in installed_plugin_ids or not plugin_unique_identifier:
        return
    installer_manager.install_from_identifiers(
        tenant_id,
        [plugin_unique_identifier],
        PluginInstallationSource.Marketplace,
        metas=[
            {
                "plugin_unique_identifier": plugin_unique_identifier,
            }
        ],
    )


def _load_api_key_config(binding):
    """Return the ``config`` mapping stored on a legacy API-key binding.

    ``binding.credentials`` is accepted either as an already-parsed dict or as a
    JSON string; anything else (including a missing ``config``) yields ``{}``.
    """
    raw = binding.credentials
    if isinstance(raw, str):
        raw = json.loads(raw)
    if not isinstance(raw, dict):
        return {}
    return raw.get("config") or {}


def _migrate_notion_bindings(installer_manager, plugin_id, plugin_unique_identifier):
    """Copy every legacy Notion OAuth binding into a ``DatasourceProvider`` row.

    Returns the number of credentials staged and committed.
    """
    bindings = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
    migrated = 0
    for tenant_id, tenant_bindings in _group_bindings_by_tenant(bindings).items():
        _ensure_datasource_plugin_installed(installer_manager, tenant_id, plugin_id, plugin_unique_identifier)
        for auth_count, binding in enumerate(tenant_bindings, start=1):
            # source_info carries the workspace metadata; tolerate a missing value.
            notion_info = binding.source_info or {}
            workspace_icon = notion_info.get("workspace_icon")
            new_credentials = {
                # The legacy access token is encrypted with the tenant key before storing.
                "integration_secret": encrypter.encrypt_token(tenant_id, binding.access_token),
                "workspace_id": notion_info.get("workspace_id"),
                "workspace_name": notion_info.get("workspace_name"),
                "workspace_icon": workspace_icon,
            }
            db.session.add(
                DatasourceProvider(
                    provider="notion",
                    tenant_id=tenant_id,
                    plugin_id=plugin_id,
                    auth_type=CredentialType.OAUTH2.value,
                    encrypted_credentials=new_credentials,
                    name=f"Auth {auth_count}",
                    avatar_url=workspace_icon or "default",
                    is_default=False,
                )
            )
            migrated += 1
    db.session.commit()
    return migrated


def _migrate_api_key_bindings(
    installer_manager, *, legacy_provider, provider, plugin_id, plugin_unique_identifier, build_credentials
):
    """Copy legacy API-key bindings of ``legacy_provider`` into ``DatasourceProvider`` rows.

    ``build_credentials`` receives the binding's ``config`` mapping and returns the
    credential dict to store. Values are copied as found on the legacy binding.
    NOTE(review): presumably the legacy api_key is already encrypted at rest - confirm.

    Returns the number of credentials staged and committed.
    """
    bindings = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider=legacy_provider).all()
    migrated = 0
    for tenant_id, tenant_bindings in _group_bindings_by_tenant(bindings).items():
        _ensure_datasource_plugin_installed(installer_manager, tenant_id, plugin_id, plugin_unique_identifier)
        for auth_count, binding in enumerate(tenant_bindings, start=1):
            db.session.add(
                DatasourceProvider(
                    provider=provider,
                    tenant_id=tenant_id,
                    plugin_id=plugin_id,
                    auth_type=CredentialType.API_KEY.value,
                    encrypted_credentials=build_credentials(_load_api_key_config(binding)),
                    name=f"Auth {auth_count}",
                    avatar_url="default",
                    is_default=False,
                )
            )
            migrated += 1
    db.session.commit()
    return migrated


@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
    """
    Transform datasource credentials

    Reads the legacy Notion OAuth bindings and the legacy firecrawl / jina API-key
    bindings, makes sure the matching datasource plugin is installed for each
    tenant, and inserts one ``DatasourceProvider`` row per legacy credential
    (named ``Auth 1``, ``Auth 2``, ... per tenant). Each provider is committed
    separately; on any error the pending session state is rolled back, the error
    is printed and the command stops.
    """
    notion_plugin_id = "langgenius/notion_datasource"
    firecrawl_plugin_id = "langgenius/firecrawl_datasource"
    jina_plugin_id = "langgenius/jina_datasource"
    try:
        installer_manager = PluginInstaller()
        plugin_migration = PluginMigration()
        # Resolve all marketplace identifiers up front, before touching any tenant.
        notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
        firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
        jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)

        deal_notion_count = _migrate_notion_bindings(
            installer_manager, notion_plugin_id, notion_plugin_unique_identifier
        )
        deal_firecrawl_count = _migrate_api_key_bindings(
            installer_manager,
            legacy_provider="firecrawl",
            provider="firecrawl",
            plugin_id=firecrawl_plugin_id,
            plugin_unique_identifier=firecrawl_plugin_unique_identifier,
            build_credentials=lambda config: {
                "firecrawl_api_key": config.get("api_key"),
                "base_url": config.get("base_url"),
            },
        )
        # NOTE(review): the website crawl endpoints name this provider "jinareader";
        # confirm that legacy bindings are really stored under provider="jina".
        deal_jina_count = _migrate_api_key_bindings(
            installer_manager,
            legacy_provider="jina",
            provider="jina",
            plugin_id=jina_plugin_id,
            plugin_unique_identifier=jina_plugin_unique_identifier,
            build_credentials=lambda config: {
                "integration_secret": config.get("api_key"),
            },
        )
    except Exception as e:
        # Discard rows staged for the provider that failed; earlier providers stay committed.
        db.session.rollback()
        click.echo(click.style(f"Error transforming datasource credentials: {str(e)}", fg="red"))
        return
    click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
    click.echo(click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green"))
    click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))

View File

@ -1,4 +1,5 @@
import json
from typing import Generator, cast
from flask import request
from flask_login import current_user
@ -9,6 +10,8 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
@ -17,7 +20,9 @@ from fields.data_source_fields import integrate_list_fields, integrate_notion_in
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from models.oauth import DatasourceProvider
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -112,6 +117,18 @@ class DataSourceNotionListApi(Resource):
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion",
plugin_id="langgenius/notion_datasource",
)
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
@ -135,31 +152,49 @@ class DataSourceNotionListApi(Resource):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = session.scalars(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion",
datasource_name="notion",
tenant_id=current_user.current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
if credential:
datasource_runtime.runtime.credentials = credential
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
)
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
)
try:
pages = []
workspace_info = {}
for message in online_document_result:
result = message.result
for info in result:
workspace_info = {
"workspace_id": info.workspace_id,
"workspace_name": info.workspace_name,
"workspace_icon": info.workspace_icon,
}
for page in info.pages:
page_info = {
"page_id": page.page_id,
"page_name": page.page_name,
"type": page.type,
"parent_id": page.parent_id,
"is_bound": page.page_id in exist_page_ids,
"page_icon": page.page_icon,
}
pages.append(page_info)
except Exception as e:
raise e
return {"notion_info": {**workspace_info, "pages": pages}}, 200
class DataSourceNotionApi(Resource):
@ -167,27 +202,25 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion",
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding:
raise NotFound("Data source binding not found.")
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id,
)
@ -212,10 +245,12 @@ class DataSourceNotionApi(Resource):
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],

View File

@ -438,10 +438,12 @@ class DatasetIndexingEstimateApi(Resource):
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
@ -462,6 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
"credential_id": website_info_list["credential_id"],
},
document_model=args["doc_form"],
)

View File

@ -510,6 +510,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@ -528,6 +529,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
"credential_id": data_source_info["credential_id"],
},
document_model=document.doc_form,
)

View File

@ -23,6 +23,7 @@ class WebsiteCrawlApi(Resource):
)
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
# Create typed request and validate
@ -48,6 +49,7 @@ class WebsiteCrawlStatusApi(Resource):
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
parser.add_argument("credential_id", type=str, required=True, nullable=True, location="args")
args = parser.parse_args()
# Create typed request and validate

View File

@ -365,6 +365,7 @@ class IndexingRunner:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@ -391,6 +392,7 @@ class IndexingRunner:
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
"credential_id": data_source_info["credential_id"],
},
document_model=dataset_document.doc_form,
)

View File

@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
Notion import info.
"""
credential_id: Optional[str] = None
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
@ -35,6 +35,7 @@ class WebsiteInfo(BaseModel):
mode: str
tenant_id: str
only_main_content: bool = False
credential_id: Optional[str] = None
class ExtractSetting(BaseModel):

View File

@ -171,6 +171,7 @@ class ExtractProcessor:
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,
tenant_id=extract_setting.notion_info.tenant_id,
credential_id=extract_setting.notion_info.credential_id,
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
@ -182,6 +183,7 @@ class ExtractProcessor:
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content,
credential_id=extract_setting.website_info.credential_id,
)
return extractor.extract()
elif extract_setting.website_info.provider == "watercrawl":
@ -191,6 +193,7 @@ class ExtractProcessor:
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content,
credential_id=extract_setting.website_info.credential_id,
)
return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
@ -200,6 +203,7 @@ class ExtractProcessor:
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content,
credential_id=extract_setting.website_info.credential_id,
)
return extractor.extract()
else:

View File

@ -1,3 +1,4 @@
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
@ -15,19 +16,21 @@ class FirecrawlWebExtractor(BaseExtractor):
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
credential_id: Optional[str] = None):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
self.credential_id = credential_id
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id)
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id, self.credential_id)
if crawl_data is None:
return []
document = Document(
@ -41,7 +44,7 @@ class FirecrawlWebExtractor(BaseExtractor):
documents.append(document)
elif self.mode == "scrape":
scrape_data = WebsiteService.get_scrape_url_data(
"firecrawl", self._url, self.tenant_id, self.only_main_content
"firecrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id
)
document = Document(

View File

@ -1,3 +1,4 @@
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
@ -8,19 +9,21 @@ class JinaReaderWebExtractor(BaseExtractor):
Crawl and scrape websites and return content in clean llm-ready markdown.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False,
credential_id: Optional[str] = None):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
self.credential_id = credential_id
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id)
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id, self.credential_id)
if crawl_data is None:
return []
document = Document(

View File

@ -9,7 +9,9 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.oauth import DatasourceProvider
from models.source import DataSourceOauthBinding
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -36,16 +38,18 @@ class NotionExtractor(BaseExtractor):
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
credential_id: Optional[str] = None,
):
self._notion_access_token = None
self._document_model = document_model
self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
self._credential_id = credential_id
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
if not self._notion_access_token:
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
@ -363,23 +367,18 @@ class NotionExtractor(BaseExtractor):
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
# get credential from tenant_id and credential_id
if not credential_id:
raise Exception(f"No credential id found for tenant {tenant_id}")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_real_credential_by_id(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion",
plugin_id="langgenius/notion_datasource",
)
if not credential:
raise Exception(f"No notion credential found for tenant {tenant_id} and credential {credential_id}")
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)
return cast(str, credential["integration_secret"])

View File

@ -1,3 +1,4 @@
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService
@ -16,19 +17,21 @@ class WaterCrawlWebExtractor(BaseExtractor):
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
credential_id: Optional[str] = None):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
self.tenant_id = tenant_id
self.mode = mode
self.only_main_content = only_main_content
self.credential_id = credential_id
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id, self.credential_id)
if crawl_data is None:
return []
document = Document(
@ -42,7 +45,7 @@ class WaterCrawlWebExtractor(BaseExtractor):
documents.append(document)
elif self.mode == "scrape":
scrape_data = WebsiteService.get_scrape_url_data(
"watercrawl", self._url, self.tenant_id, self.only_main_content
"watercrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id
)
document = Document(

View File

@ -56,6 +56,33 @@ class DatasourceProviderService:
return {}
return datasource_provider.encrypted_credentials
def get_real_credential_by_id(
    self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
) -> dict[str, Any]:
    """
    Return the stored credential ``credential_id`` of ``tenant_id`` with its secret fields decrypted.

    The row is looked up by tenant and id only; ``provider`` and ``plugin_id`` are
    used to build the provider id (``"{plugin_id}/{provider}"``) from which the
    names of the secret fields are resolved.
    NOTE(review): the lookup does not check that the row belongs to the given
    provider/plugin - confirm whether that is intended.

    :return: a copy of the stored credentials with every secret field decrypted,
        or ``{}`` when no matching row exists.
    """
    with Session(db.engine) as session:
        record = session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
        if not record:
            return {}

        stored_credentials = record.encrypted_credentials
        # Names of the credential fields that are stored encrypted
        secret_field_names = self.extract_secret_variables(
            tenant_id=tenant_id,
            provider_id=f"{plugin_id}/{provider}",
            credential_type=CredentialType.of(record.auth_type),
        )

        # Decrypt on a copy so the loaded row is left untouched
        real_credentials = stored_credentials.copy()
        fields_to_decrypt = [name for name in real_credentials if name in secret_field_names]
        for name in fields_to_decrypt:
            real_credentials[name] = encrypter.decrypt_token(tenant_id, real_credentials[name])
        return real_credentials
def update_datasource_provider_name(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):

View File

@ -12,6 +12,7 @@ from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
from services.datasource_provider_service import DatasourceProviderService
@dataclass
@ -61,6 +62,7 @@ class WebsiteCrawlApiRequest:
provider: str
url: str
options: dict[str, Any]
credential_id: Optional[str] = None
def to_crawl_request(self) -> CrawlRequest:
"""Convert API request to internal CrawlRequest."""
@ -98,30 +100,48 @@ class WebsiteCrawlStatusApiRequest:
provider: str
job_id: str
credential_id: Optional[str] = None
@classmethod
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
credential_id = args.get("credential_id")
if not provider:
raise ValueError("Provider is required")
if not job_id:
raise ValueError("Job ID is required")
return cls(provider=provider, job_id=job_id)
return cls(provider=provider, job_id=job_id, credential_id=credential_id)
class WebsiteService:
"""Service class for website crawling operations using different providers."""
@classmethod
def _get_credentials_and_config(
    cls, tenant_id: str, provider: str, credential_id: Optional[str] = None
) -> tuple[Any, Any]:
    """Get and validate credentials for a provider.

    Without ``credential_id`` the legacy API-key binding is used and the result is
    ``(credentials, credentials["config"])``. With ``credential_id`` the credential
    is loaded through ``DatasourceProviderService`` and the result is
    ``(api_key, credential)``.

    :raises ValueError: if the provider is unknown (credential_id path) or no
        usable credentials are found.
    """
    if not credential_id:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
        if not credentials or "config" not in credentials:
            raise ValueError("No valid credentials found for the provider")
        return credentials, credentials["config"]

    plugin_ids = {
        "firecrawl": "langgenius/firecrawl_datasource",
        "watercrawl": "langgenius/watercrawl_datasource",
        "jinareader": "langgenius/jinareader_datasource",
    }
    plugin_id = plugin_ids.get(provider)
    if plugin_id is None:
        # Previously an unknown provider left plugin_id unbound (UnboundLocalError).
        raise ValueError("Invalid provider")

    datasource_provider_service = DatasourceProviderService()
    credential = datasource_provider_service.get_real_credential_by_id(
        tenant_id=tenant_id,
        credential_id=credential_id,
        provider=provider,
        plugin_id=plugin_id,
    )
    if not credential:
        # An empty result means the credential id is unknown for this tenant.
        raise ValueError("No valid credentials found for the provider")

    # "api_key" keeps priority. The other names are the keys written by the
    # transform-datasource-credentials command for migrated credentials.
    # NOTE(review): confirm the key names against each plugin's credential schema.
    api_key = credential.get("api_key")
    if not api_key:
        api_key = credential.get("firecrawl_api_key") or credential.get("integration_secret") or api_key
    return api_key, credential
@classmethod
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
@ -144,8 +164,11 @@ class WebsiteService:
"""Crawl a URL using the specified provider with typed request."""
request = api_request.to_crawl_request()
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider, api_request.credential_id)
if api_request.credential_id:
api_key = _
else:
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
if request.provider == "firecrawl":
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
@ -227,16 +250,21 @@ class WebsiteService:
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
def get_crawl_status(cls, job_id: str, provider: str, credential_id: Optional[str] = None) -> dict[str, Any]:
"""Get crawl status using string parameters."""
api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id, credential_id=credential_id)
return cls.get_crawl_status_typed(api_request)
@classmethod
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
"""Get crawl status using typed request."""
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
_, config = cls._get_credentials_and_config(current_user.current_tenant_id,
api_request.provider,
api_request.credential_id)
if api_request.credential_id:
api_key = _
else:
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
if api_request.provider == "firecrawl":
return cls._get_firecrawl_status(api_request.job_id, api_key, config)
@ -309,9 +337,12 @@ class WebsiteService:
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
_, config = cls._get_credentials_and_config(tenant_id, provider)
api_key = cls._get_decrypted_api_key(tenant_id, config)
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str, credential_id: Optional[str] = None) -> dict[str, Any] | None:
_, config = cls._get_credentials_and_config(tenant_id, provider, credential_id)
if credential_id:
api_key = _
else:
api_key = cls._get_decrypted_api_key(tenant_id, config)
if provider == "firecrawl":
return cls._get_firecrawl_url_data(job_id, url, api_key, config)
@ -381,11 +412,17 @@ class WebsiteService:
return None
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool,
credential_id: Optional[str] = None) -> dict[str, Any]:
request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id,
provider=request.provider,
credential_id=credential_id)
if credential_id:
api_key = _
else:
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
if request.provider == "firecrawl":
return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)