mirror of https://github.com/langgenius/dify.git
add old auth transform
This commit is contained in:
parent
b0cd4daf54
commit
829e6f0d1a
187
api/commands.py
187
api/commands.py
|
|
@ -12,11 +12,15 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, ToolProviderID
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -29,10 +33,12 @@ from models import Tenant
|
|||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.oauth import DatasourceOauthParamConfig
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.auth import firecrawl
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
|
|
@ -1248,3 +1254,180 @@ def setup_datasource_oauth_client(provider, client_params):
|
|||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
def _group_bindings_by_tenant(bindings):
    """Group legacy credential binding rows by tenant id (insertion order kept)."""
    mapping: dict[str, list] = {}
    for binding in bindings:
        mapping.setdefault(binding.tenant_id, []).append(binding)
    return mapping


def _ensure_plugin_installed(installer_manager, tenant_id, plugin_id, plugin_unique_identifier):
    """Install ``plugin_id`` for the tenant from the marketplace if it is missing.

    No-op when the plugin is already installed, or when no unique identifier
    could be resolved for it (nothing to install from).
    """
    installed_ids = {plugin.plugin_id for plugin in installer_manager.list_plugins(tenant_id)}
    if plugin_id in installed_ids or not plugin_unique_identifier:
        return
    installer_manager.install_from_identifiers(
        tenant_id,
        [plugin_unique_identifier],
        PluginInstallationSource.Marketplace,
        metas=[
            {
                "plugin_unique_identifier": plugin_unique_identifier,
            }
        ],
    )


@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
    """
    Transform legacy datasource credentials into DatasourceProvider records.

    Migrates notion OAuth bindings (DataSourceOauthBinding) and firecrawl/jina
    API-key bindings (DataSourceApiKeyAuthBinding) to the plugin-based
    datasource credential storage, installing the matching datasource plugin
    for each tenant when necessary. Commits once per provider section.
    """
    try:
        installer_manager = PluginInstaller()
        plugin_migration = PluginMigration()

        notion_plugin_id = "langgenius/notion_datasource"
        firecrawl_plugin_id = "langgenius/firecrawl_datasource"
        jina_plugin_id = "langgenius/jina_datasource"
        notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
        firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
        jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
        oauth_credential_type = CredentialType.OAUTH2
        api_key_credential_type = CredentialType.API_KEY

        # deal notion credentials (OAuth bindings)
        deal_notion_count = 0
        notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
        for tenant_id, credentials in _group_bindings_by_tenant(notion_credentials).items():
            _ensure_plugin_installed(installer_manager, tenant_id, notion_plugin_id, notion_plugin_unique_identifier)
            for auth_count, credential in enumerate(credentials, start=1):
                notion_info = credential.source_info
                workspace_icon = notion_info.get("workspace_icon")
                new_credentials = {
                    # the workspace OAuth token is a secret field -> encrypted at rest
                    "integration_secret": encrypter.encrypt_token(tenant_id, credential.access_token),
                    "workspace_id": notion_info.get("workspace_id"),
                    "workspace_name": notion_info.get("workspace_name"),
                    "workspace_icon": workspace_icon,
                }
                db.session.add(
                    DatasourceProvider(
                        provider="notion",
                        tenant_id=tenant_id,
                        plugin_id=notion_plugin_id,
                        auth_type=oauth_credential_type.value,
                        encrypted_credentials=new_credentials,
                        name=f"Auth {auth_count}",
                        avatar_url=workspace_icon or "default",
                        is_default=False,
                    )
                )
                deal_notion_count += 1
        db.session.commit()

        # deal firecrawl credentials (API-key bindings)
        deal_firecrawl_count = 0
        firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
        for tenant_id, credentials in _group_bindings_by_tenant(firecrawl_credentials).items():
            _ensure_plugin_installed(
                installer_manager, tenant_id, firecrawl_plugin_id, firecrawl_plugin_unique_identifier
            )
            for auth_count, credential in enumerate(credentials, start=1):
                config = credential.credentials.get("config", {})
                api_key = config.get("api_key")
                new_credentials = {
                    # NOTE(review): legacy rows kept the api key in plaintext; encrypt it so
                    # DatasourceProviderService can decrypt secret fields uniformly
                    # (the notion path above already stores its secret encrypted).
                    "firecrawl_api_key": encrypter.encrypt_token(tenant_id, api_key) if api_key else api_key,
                    "base_url": config.get("base_url"),
                }
                db.session.add(
                    DatasourceProvider(
                        provider="firecrawl",
                        tenant_id=tenant_id,
                        plugin_id=firecrawl_plugin_id,
                        auth_type=api_key_credential_type.value,
                        encrypted_credentials=new_credentials,
                        name=f"Auth {auth_count}",
                        avatar_url="default",
                        is_default=False,
                    )
                )
                deal_firecrawl_count += 1
        db.session.commit()

        # deal jina credentials (API-key bindings)
        deal_jina_count = 0
        jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jina").all()
        for tenant_id, credentials in _group_bindings_by_tenant(jina_credentials).items():
            _ensure_plugin_installed(installer_manager, tenant_id, jina_plugin_id, jina_plugin_unique_identifier)
            for auth_count, credential in enumerate(credentials, start=1):
                api_key = credential.credentials.get("config", {}).get("api_key")
                new_credentials = {
                    # see NOTE above: encrypt the secret field for consistent decryption
                    "integration_secret": encrypter.encrypt_token(tenant_id, api_key) if api_key else api_key,
                }
                db.session.add(
                    DatasourceProvider(
                        provider="jina",
                        tenant_id=tenant_id,
                        plugin_id=jina_plugin_id,
                        auth_type=api_key_credential_type.value,
                        encrypted_credentials=new_credentials,
                        name=f"Auth {auth_count}",
                        avatar_url="default",
                        is_default=False,
                    )
                )
                deal_jina_count += 1
        db.session.commit()
    except Exception as e:
        # fixed message: previous text ("Error parsing client params") was
        # copy-pasted from the oauth-client setup command
        click.echo(click.style(f"Error transforming datasource credentials: {str(e)}", fg="red"))
        return
    click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
    click.echo(
        click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
    )
    click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
from typing import Generator, cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
|
@ -9,6 +10,8 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
|
@ -17,7 +20,9 @@ from fields.data_source_fields import integrate_list_fields, integrate_notion_in
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from models.oauth import DatasourceProvider
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
|
|
@ -112,6 +117,18 @@ class DataSourceNotionListApi(Resource):
|
|||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
|
|
@ -135,31 +152,49 @@ class DataSourceNotionListApi(Resource):
|
|||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion",
|
||||
datasource_name="notion",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if credential:
|
||||
datasource_runtime.runtime.credentials = credential
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
)
|
||||
try:
|
||||
pages = []
|
||||
workspace_info = {}
|
||||
for message in online_document_result:
|
||||
result = message.result
|
||||
for info in result:
|
||||
workspace_info = {
|
||||
"workspace_id": info.workspace_id,
|
||||
"workspace_name": info.workspace_name,
|
||||
"workspace_icon": info.workspace_icon,
|
||||
}
|
||||
for page in info.pages:
|
||||
page_info = {
|
||||
"page_id": page.page_id,
|
||||
"page_name": page.page_name,
|
||||
"type": page.type,
|
||||
"parent_id": page.parent_id,
|
||||
"is_bound": page.page_id in exist_page_ids,
|
||||
"page_icon": page.page_icon,
|
||||
}
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
|
|
@ -167,27 +202,25 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
|
|
@ -212,10 +245,12 @@ class DataSourceNotionApi(Resource):
|
|||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
|
|
|||
|
|
@ -438,10 +438,12 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
|
@ -462,6 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
"credential_id": website_info_list["credential_id"],
|
||||
},
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -510,6 +510,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
@ -528,6 +529,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
},
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class WebsiteCrawlApi(Resource):
|
|||
)
|
||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create typed request and validate
|
||||
|
|
@ -48,6 +49,7 @@ class WebsiteCrawlStatusApi(Resource):
|
|||
parser.add_argument(
|
||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
||||
)
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create typed request and validate
|
||||
|
|
|
|||
|
|
@ -365,6 +365,7 @@ class IndexingRunner:
|
|||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
@ -391,6 +392,7 @@ class IndexingRunner:
|
|||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
},
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
|
|||
"""
|
||||
Notion import info.
|
||||
"""
|
||||
|
||||
credential_id: Optional[str] = None
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
|
|
@ -35,6 +35,7 @@ class WebsiteInfo(BaseModel):
|
|||
mode: str
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ class ExtractProcessor:
|
|||
notion_page_type=extract_setting.notion_info.notion_page_type,
|
||||
document_model=extract_setting.notion_info.document,
|
||||
tenant_id=extract_setting.notion_info.tenant_id,
|
||||
credential_id=extract_setting.notion_info.credential_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
|
||||
|
|
@ -182,6 +183,7 @@ class ExtractProcessor:
|
|||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
credential_id=extract_setting.website_info.credential_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "watercrawl":
|
||||
|
|
@ -191,6 +193,7 @@ class ExtractProcessor:
|
|||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
credential_id=extract_setting.website_info.credential_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "jinareader":
|
||||
|
|
@ -200,6 +203,7 @@ class ExtractProcessor:
|
|||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
credential_id=extract_setting.website_info.credential_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Optional
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
|
@ -15,19 +16,21 @@ class FirecrawlWebExtractor(BaseExtractor):
|
|||
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
|
||||
credential_id: Optional[str] = None):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
self.credential_id = credential_id
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id)
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id, self.credential_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(
|
||||
|
|
@ -41,7 +44,7 @@ class FirecrawlWebExtractor(BaseExtractor):
|
|||
documents.append(document)
|
||||
elif self.mode == "scrape":
|
||||
scrape_data = WebsiteService.get_scrape_url_data(
|
||||
"firecrawl", self._url, self.tenant_id, self.only_main_content
|
||||
"firecrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id
|
||||
)
|
||||
|
||||
document = Document(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Optional
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
|
@ -8,19 +9,21 @@ class JinaReaderWebExtractor(BaseExtractor):
|
|||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False,
|
||||
credential_id: Optional[str] = None):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
self.credential_id = credential_id
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id)
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id, self.credential_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.source import DataSourceOauthBinding
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,16 +38,18 @@ class NotionExtractor(BaseExtractor):
|
|||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
credential_id: Optional[str] = None,
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
self._notion_workspace_id = notion_workspace_id
|
||||
self._notion_obj_id = notion_obj_id
|
||||
self._notion_page_type = notion_page_type
|
||||
self._credential_id = credential_id
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
||||
if integration_token is None:
|
||||
|
|
@ -363,23 +367,18 @@ class NotionExtractor(BaseExtractor):
|
|||
return cast(str, data["last_edited_time"])
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
|
||||
# get credential from tenant_id and credential_id
|
||||
if not credential_id:
|
||||
raise Exception(f"No credential id found for tenant {tenant_id}")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise Exception(f"No notion credential found for tenant {tenant_id} and credential {credential_id}")
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(
|
||||
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
|
||||
)
|
||||
|
||||
return cast(str, data_source_binding.access_token)
|
||||
return cast(str, credential["integration_secret"])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Optional
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
|
@ -16,19 +17,21 @@ class WaterCrawlWebExtractor(BaseExtractor):
|
|||
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
|
||||
credential_id: Optional[str] = None):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
self.credential_id = credential_id
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id, self.credential_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(
|
||||
|
|
@ -42,7 +45,7 @@ class WaterCrawlWebExtractor(BaseExtractor):
|
|||
documents.append(document)
|
||||
elif self.mode == "scrape":
|
||||
scrape_data = WebsiteService.get_scrape_url_data(
|
||||
"watercrawl", self._url, self.tenant_id, self.only_main_content
|
||||
"watercrawl", self._url, self.tenant_id, self.only_main_content, self.credential_id
|
||||
)
|
||||
|
||||
document = Document(
|
||||
|
|
|
|||
|
|
@ -56,6 +56,33 @@ class DatasourceProviderService:
|
|||
return {}
|
||||
return datasource_provider.encrypted_credentials
|
||||
|
||||
    def get_real_credential_by_id(
        self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
    ) -> dict[str, Any]:
        """
        Return the decrypted credential payload for a datasource credential id.

        Looks up the tenant's DatasourceProvider row by id and decrypts every
        field declared secret by the provider's credential schema; non-secret
        fields pass through unchanged. Returns an empty dict when no matching
        credential exists.
        """
        with Session(db.engine) as session:
            datasource_provider = (
                session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, id=credential_id)
                .first()
            )
            if not datasource_provider:
                return {}
            encrypted_credentials = datasource_provider.encrypted_credentials
            # Get provider credential secret variables (which keys are secrets)
            credential_secret_variables = self.extract_secret_variables(
                tenant_id=tenant_id,
                provider_id=f"{plugin_id}/{provider}",
                credential_type=CredentialType.of(datasource_provider.auth_type),
            )

            # Decrypt secret fields in place on a shallow copy of the stored payload
            copy_credentials = encrypted_credentials.copy()
            for key, value in copy_credentials.items():
                if key in credential_secret_variables:
                    copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
            return copy_credentials
|
||||
|
||||
def update_datasource_provider_name(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
|
||||
):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
|
|||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -61,6 +62,7 @@ class WebsiteCrawlApiRequest:
|
|||
provider: str
|
||||
url: str
|
||||
options: dict[str, Any]
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
def to_crawl_request(self) -> CrawlRequest:
|
||||
"""Convert API request to internal CrawlRequest."""
|
||||
|
|
@ -98,30 +100,48 @@ class WebsiteCrawlStatusApiRequest:
|
|||
|
||||
provider: str
|
||||
job_id: str
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
@classmethod
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
    """Build a crawl-status request from Flask-RESTful parsed arguments.

    :raises ValueError: when the provider or the job id is missing.
    """
    provider = args.get("provider")
    if not provider:
        raise ValueError("Provider is required")
    if not job_id:
        raise ValueError("Job ID is required")

    return cls(
        provider=provider,
        job_id=job_id,
        credential_id=args.get("credential_id"),
    )
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
"""Service class for website crawling operations using different providers."""
|
||||
|
||||
@classmethod
def _get_credentials_and_config(
    cls, tenant_id: str, provider: str, credential_id: Optional[str] = None
) -> tuple[Any, Any]:
    """Resolve credentials for a website-crawl provider.

    With ``credential_id``, the plugin-based datasource credential store is
    consulted and ``(api_key, credential_dict)`` is returned with the api
    key already decrypted. Without it, the legacy api-key auth binding is
    used and ``(credentials, credentials["config"])`` is returned, where
    the config still holds the encrypted key.

    :raises ValueError: if the provider is unknown (credential path) or no
        valid legacy credentials exist (legacy path).
    """
    if credential_id:
        # Map each supported provider to its datasource plugin id. The
        # previous if/elif chain left ``plugin_id`` unbound for any other
        # provider (UnboundLocalError); fail with a clear error instead.
        plugin_ids = {
            "firecrawl": "langgenius/firecrawl_datasource",
            "watercrawl": "langgenius/watercrawl_datasource",
            "jinareader": "langgenius/jinareader_datasource",
        }
        plugin_id = plugin_ids.get(provider)
        if plugin_id is None:
            raise ValueError(f"Unsupported provider for credential lookup: {provider}")
        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_real_credential_by_id(
            tenant_id=tenant_id,
            credential_id=credential_id,
            provider=provider,
            plugin_id=plugin_id,
        )
        return credential.get("api_key"), credential
    else:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
        if not credentials or "config" not in credentials:
            raise ValueError("No valid credentials found for the provider")
        return credentials, credentials["config"]
|
||||
|
||||
@classmethod
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
|
||||
|
|
@ -144,8 +164,11 @@ class WebsiteService:
|
|||
"""Crawl a URL using the specified provider with typed request."""
|
||||
request = api_request.to_crawl_request()
|
||||
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider, api_request.credential_id)
|
||||
if api_request.credential_id:
|
||||
api_key = _
|
||||
else:
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
|
||||
if request.provider == "firecrawl":
|
||||
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
|
||||
|
|
@ -227,16 +250,21 @@ class WebsiteService:
|
|||
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
|
||||
|
||||
@classmethod
def get_crawl_status(cls, job_id: str, provider: str, credential_id: Optional[str] = None) -> dict[str, Any]:
    """String-parameter convenience wrapper around ``get_crawl_status_typed``."""
    return cls.get_crawl_status_typed(
        WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id, credential_id=credential_id)
    )
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
|
||||
"""Get crawl status using typed request."""
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id,
|
||||
api_request.provider,
|
||||
api_request.credential_id)
|
||||
if api_request.credential_id:
|
||||
api_key = _
|
||||
else:
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
|
||||
if api_request.provider == "firecrawl":
|
||||
return cls._get_firecrawl_status(api_request.job_id, api_key, config)
|
||||
|
|
@ -309,9 +337,12 @@ class WebsiteService:
|
|||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
|
||||
_, config = cls._get_credentials_and_config(tenant_id, provider)
|
||||
api_key = cls._get_decrypted_api_key(tenant_id, config)
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str, credential_id: Optional[str] = None) -> dict[str, Any] | None:
|
||||
_, config = cls._get_credentials_and_config(tenant_id, provider, credential_id)
|
||||
if credential_id:
|
||||
api_key = _
|
||||
else:
|
||||
api_key = cls._get_decrypted_api_key(tenant_id, config)
|
||||
|
||||
if provider == "firecrawl":
|
||||
return cls._get_firecrawl_url_data(job_id, url, api_key, config)
|
||||
|
|
@ -381,11 +412,17 @@ class WebsiteService:
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool,
|
||||
credential_id: Optional[str] = None) -> dict[str, Any]:
|
||||
request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
|
||||
|
||||
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
|
||||
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
|
||||
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id,
|
||||
provider=request.provider,
|
||||
credential_id=credential_id)
|
||||
if credential_id:
|
||||
api_key = _
|
||||
else:
|
||||
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
|
||||
|
||||
if request.provider == "firecrawl":
|
||||
return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)
|
||||
|
|
|
|||
Loading…
Reference in New Issue