add pipeline async run

jyong 2025-08-26 15:20:40 +08:00
parent 0f3ca1d8f4
commit c77bdd1fb3
3 changed files with 121 additions and 1 deletion


@@ -1568,3 +1568,25 @@ def transform_datasource_credentials():
         click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
     )
     click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
+
+
+@click.command("install-rag-pipeline-plugins", help="Install RAG pipeline plugins.")
+@click.option(
+    "--input_file", prompt=True, help="The file containing the extracted unique plugin identifiers.", default="plugins.jsonl"
+)
+@click.option(
+    "--output_file", prompt=True, help="The file to store the installation results.", default="installed_plugins.jsonl"
+)
+@click.option("--workers", prompt=True, help="The number of worker threads used to install plugins.", default=100)
+def install_rag_pipeline_plugins(input_file, output_file, workers):
+    """
+    Install RAG pipeline plugins for all tenants.
+    """
+    click.echo(click.style("Installing RAG pipeline plugins", fg="yellow"))
+    plugin_migration = PluginMigration()
+    plugin_migration.install_rag_pipeline_plugins(
+        input_file,
+        output_file,
+        workers,
+    )
+    click.echo(click.style("RAG pipeline plugins installed successfully", fg="green"))


@@ -420,6 +420,101 @@ class PluginMigration:
             )
         )
+
+    @classmethod
+    def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
+        """
+        Install RAG pipeline plugins for all tenants.
+        """
+        manager = PluginInstaller()
+        plugins = cls.extract_unique_plugins(extracted_plugins)
+        plugin_install_failed = []
+
+        # Use a fake tenant id to install all the plugins once up front.
+        fake_tenant_id = uuid4().hex
+        logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
+        response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
+        if response.get("failed"):
+            plugin_install_failed.extend(response.get("failed", []))
+
+        def install(tenant_id: str, plugin_ids: dict[str, str]) -> bool:
+            """Install the given plugins for one tenant; return True on success."""
+            logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
+            try:
+                # Skip plugins the tenant already has installed.
+                installed_plugins = manager.list_plugins(tenant_id)
+                installed_plugin_ids = {plugin.plugin_id for plugin in installed_plugins}
+                all_plugin_ids = list(plugin_ids.keys())
+                # Install at most 64 plugins per batch.
+                for i in range(0, len(all_plugin_ids), 64):
+                    batch_plugin_identifiers = [
+                        plugin_ids[plugin_id]
+                        for plugin_id in all_plugin_ids[i : i + 64]
+                        if plugin_id not in installed_plugin_ids
+                    ]
+                    if not batch_plugin_identifiers:
+                        continue
+                    manager.install_from_identifiers(
+                        tenant_id,
+                        batch_plugin_identifiers,
+                        PluginInstallationSource.Marketplace,
+                        metas=[
+                            {
+                                "plugin_unique_identifier": identifier,
+                            }
+                            for identifier in batch_plugin_identifiers
+                        ],
+                    )
+                return True
+            except Exception:
+                logger.exception("Failed to install plugins for tenant %s", tenant_id)
+                return False
+
+        # Paginate over all tenants and install concurrently. Each worker
+        # reports success through its future; incrementing plain integer
+        # arguments from worker threads would never be visible here.
+        thread_pool = ThreadPoolExecutor(max_workers=workers)
+        futures = []
+        page = 1
+        while True:
+            tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
+            if not tenants.items:
+                break
+            for tenant in tenants:
+                futures.append(thread_pool.submit(install, tenant.id, plugins.get("plugins", {})))
+            page += 1
+        thread_pool.shutdown(wait=True)
+        results = [future.result() for future in futures]
+        total_success_tenant = sum(1 for succeeded in results if succeeded)
+        total_failed_tenant = len(results) - total_success_tenant
+
+        # Uninstall all the plugins from the fake tenant.
+        try:
+            installation = manager.list_plugins(fake_tenant_id)
+            while installation:
+                for plugin in installation:
+                    manager.uninstall(fake_tenant_id, plugin.installation_id)
+                installation = manager.list_plugins(fake_tenant_id)
+        except Exception:
+            logger.exception("Failed to uninstall plugins for fake tenant %s", fake_tenant_id)
+
+        Path(output_file).write_text(
+            json.dumps(
+                {
+                    "total_success_tenant": total_success_tenant,
+                    "total_failed_tenant": total_failed_tenant,
+                    "plugin_install_failed": plugin_install_failed,
+                }
+            )
+        )
     @classmethod
     def handle_plugin_instance_install(
         cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
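
The method above indexes plugins["plugins"] as a mapping from plugin id to its marketplace unique identifier and installs in slices of 64. A minimal, self-contained sketch of that batching pattern; the ids and version strings below are made up for illustration, and the mapping shape is inferred from how the method indexes it:

    plugins: dict[str, str] = {
        "vendor/plugin-a": "vendor/plugin-a:0.0.1@aaa",  # illustrative values
        "vendor/plugin-b": "vendor/plugin-b:0.0.2@bbb",
    }

    plugin_ids = list(plugins.keys())
    for i in range(0, len(plugin_ids), 64):
        batch = [plugins[pid] for pid in plugin_ids[i : i + 64]]
        print(batch)  # each batch carries at most 64 unique identifiers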


@@ -300,9 +300,12 @@ class RagPipelineDslService:
         ):
             raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
+            datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            names = [dataset.name for dataset in datasets]
+            generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
                 tenant_id=account.current_tenant_id,
-                name=name + datetime.now(UTC).strftime("%Y%m%d%H%M%S%f"),
+                name=generate_name,
                 description=description,
                 icon_info={
                     "type": icon_type,