mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/rag-2' into feat/workflow-draft-var-optimize
This commit is contained in:
commit
e9e7d4f3cd
|
|
@ -8,6 +8,7 @@ on:
|
|||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "deploy/rag-dev"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ on:
|
|||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/dev"
|
||||
- "deploy/rag-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
|
@ -12,12 +12,13 @@ jobs:
|
|||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
|
|
|||
|
|
@ -460,6 +460,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
|||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
output_registry
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handling
|
||||
graph_traversal
|
||||
state_management
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:state-management-layers]
|
||||
name = State Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
execution_tracker
|
||||
node_state_manager
|
||||
edge_state_manager
|
||||
containers =
|
||||
core.workflow.graph_engine.state_management
|
||||
|
||||
[importlinter:contract:worker-management-layers]
|
||||
name = Worker Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
worker_pool
|
||||
worker_factory
|
||||
dynamic_scaler
|
||||
activity_tracker
|
||||
containers =
|
||||
core.workflow.graph_engine.worker_management
|
||||
|
||||
[importlinter:contract:error-handling-strategies]
|
||||
name = Error Handling Strategies
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.error_handling.abort_strategy
|
||||
core.workflow.graph_engine.error_handling.retry_strategy
|
||||
core.workflow.graph_engine.error_handling.fail_branch_strategy
|
||||
core.workflow.graph_engine.error_handling.default_value_strategy
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.graph_traversal.node_readiness
|
||||
core.workflow.graph_engine.graph_traversal.skip_propagator
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
27
api/app.py
27
api/app.py
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
|
||||
|
|
@ -17,20 +16,20 @@ else:
|
|||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
|
|
|||
255
api/commands.py
255
api/commands.py
|
|
@ -13,11 +13,14 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -30,7 +33,10 @@ from models import Tenant
|
|||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
|
@ -1354,3 +1360,250 @@ def cleanup_orphaned_draft_variables(
|
|||
continue
|
||||
|
||||
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
def transform_datasource_credentials():
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
try:
|
||||
installer_manager = PluginInstaller()
|
||||
plugin_migration = PluginMigration()
|
||||
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for credential in notion_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in notion_credentials_tenant_mapping:
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
installer_manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[notion_plugin_unique_identifier],
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": notion_plugin_unique_identifier,
|
||||
}
|
||||
],
|
||||
)
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = credential.access_token
|
||||
# notion info
|
||||
notion_info = credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for credential in firecrawl_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in firecrawl_credentials_tenant_mapping:
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
installer_manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[firecrawl_plugin_unique_identifier],
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": firecrawl_plugin_unique_identifier,
|
||||
}
|
||||
],
|
||||
)
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
api_key = credential.credentials.get("config", {}).get("api_key")
|
||||
base_url = credential.credentials.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jina").all()
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for credential in jina_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in jina_credentials_tenant_mapping:
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
installer_manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[jina_plugin_unique_identifier],
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": jina_plugin_unique_identifier,
|
||||
}
|
||||
],
|
||||
)
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
api_key = credential.credentials.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
|
||||
click.echo(
|
||||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_rag_pipeline_plugins(input_file, output_file, workers):
|
||||
"""
|
||||
Install rag pipeline plugins
|
||||
"""
|
||||
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
|
||||
plugin_migration = PluginMigration()
|
||||
plugin_migration.install_rag_pipeline_plugins(
|
||||
input_file,
|
||||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
|
|
|
|||
|
|
@ -545,6 +545,28 @@ class WorkflowConfig(BaseSettings):
|
|||
default=200 * 1024,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="database",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from threading import Lock
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
|
@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
|||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
|||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
from . import admin, apikey, extension, feature, ping, setup, spec, version
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
|
|
@ -86,6 +86,15 @@ from .datasets import (
|
|||
metadata,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
datasource_auth,
|
||||
datasource_content_preview,
|
||||
rag_pipeline,
|
||||
rag_pipeline_datasets,
|
||||
rag_pipeline_draft_variable,
|
||||
rag_pipeline_import,
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
|
|
|||
|
|
@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
|
|
@ -135,9 +138,6 @@ class InstructionGenerateApi(Resource):
|
|||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
|
|
@ -413,7 +414,12 @@ class WorkflowTaskStopApi(Resource):
|
|||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
|
|
|
|||
|
|
@ -18,10 +18,11 @@ from core.variables.segment_group import SegmentGroup
|
|||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
|
@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
|
@ -18,6 +22,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
|
|
@ -112,6 +117,18 @@ class DataSourceNotionListApi(Resource):
|
|||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
|
|
@ -135,31 +152,49 @@ class DataSourceNotionListApi(Resource):
|
|||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if credential:
|
||||
datasource_runtime.runtime.credentials = credential
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
)
|
||||
try:
|
||||
pages = []
|
||||
workspace_info = {}
|
||||
for message in online_document_result:
|
||||
result = message.result
|
||||
for info in result:
|
||||
workspace_info = {
|
||||
"workspace_id": info.workspace_id,
|
||||
"workspace_name": info.workspace_name,
|
||||
"workspace_icon": info.workspace_icon,
|
||||
}
|
||||
for page in info.pages:
|
||||
page_info = {
|
||||
"page_id": page.page_id,
|
||||
"page_name": page.page_name,
|
||||
"type": page.type,
|
||||
"parent_id": page.parent_id,
|
||||
"is_bound": page.page_id in exist_page_ids,
|
||||
"page_icon": page.page_icon,
|
||||
}
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
|
|
@ -167,27 +202,25 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
|
|
@ -212,10 +245,12 @@ class DataSourceNotionApi(Resource):
|
|||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from controllers.console.wraps import (
|
|||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
|
|
@ -31,6 +30,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
|
|
@ -279,6 +279,15 @@ class DatasetApi(Resource):
|
|||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
|
|
@ -429,10 +438,12 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from typing import Literal, cast
|
||||
|
|
@ -51,6 +52,7 @@ from fields.document_fields import (
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
|
|
@ -496,6 +498,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
@ -649,7 +652,7 @@ class DocumentApi(DocumentResource):
|
|||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
|
|
@ -1012,6 +1015,41 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
"datasource_info": None,
|
||||
"datasource_type": None,
|
||||
"input_data": None,
|
||||
"datasource_node_id": None,
|
||||
}, 200
|
||||
return {
|
||||
"datasource_info": json.loads(log.datasource_info),
|
||||
"datasource_type": log.datasource_type,
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
|
|
@ -1033,3 +1071,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
|||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
|||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
||||
|
|
|
|||
|
|
@ -0,0 +1,362 @@
|
|||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
oauth_config = DatasourceProviderService().get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_config:
|
||||
raise ValueError(f"No OAuth Client Config for {provider_id}")
|
||||
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_config,
|
||||
)
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
context_id = request.cookies.get("context_id") or request.args.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
oauth_client_params = datasource_provider_service.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_client_params:
|
||||
raise NotFound()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOAuthCallback,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource/<path:provider_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceHardCodeAuthListApi,
|
||||
"/auth/plugin/datasource/default-list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
knowledge_pipeline_publish_enabled,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=args["yaml_content"],
|
||||
)
|
||||
try:
|
||||
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
|
|
@ -0,0 +1,417 @@
|
|||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(pipeline.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
dataset_name=args.get("name"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,47 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def get_rag_pipeline(
|
||||
view: Optional[Callable] = None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user is not an account")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
|
@ -20,6 +20,7 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.model import AppMode, InstalledApp
|
||||
|
|
@ -78,6 +79,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
import logging
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
from libs.login import login_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpecSchemaDefinitionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""
|
||||
Get system JSON Schema definitions specification
|
||||
Used for frontend component type mapping
|
||||
"""
|
||||
try:
|
||||
schema_manager = SchemaManager()
|
||||
schema_definitions = schema_manager.get_all_schema_definitions()
|
||||
return schema_definitions, 200
|
||||
except Exception:
|
||||
logger.exception("Failed to get schema definitions from local registry")
|
||||
# Return empty array as fallback
|
||||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
|
|
@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
|
|||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
|
|
|||
|
|
@ -261,3 +261,14 @@ def is_allow_transfer_owner(view):
|
|||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
|
|||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ from core.errors.error import (
|
|||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
|
|
@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
|
|||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
|
|
|||
|
|
@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
|
@ -110,7 +111,12 @@ class WorkflowTaskStopApi(WebApiResource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
|
|||
tenant_id=tenant_id,
|
||||
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
|
||||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
return_resource=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from typing import Any
|
|||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
|
|
|
|||
|
|
@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
|
|||
hide: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
|
||||
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
|
|||
return v or []
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
|
||||
tooltips: Optional[str] = None
|
||||
placeholder: Optional[str] = None
|
||||
belong_to_node_id: str
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
|
|
@ -288,7 +298,7 @@ class AppConfig(BaseModel):
|
|||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures
|
||||
additional_features: Optional[AppAdditionalFeatures] = None
|
||||
variables: list[VariableEntity] = []
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from core.app.app_config.entities import VariableEntity
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
|
@ -20,3 +22,44 @@ class WorkflowVariablesConfigManager:
|
|||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
Convert workflow start variables to variables
|
||||
|
||||
:param workflow: workflow instance
|
||||
"""
|
||||
variables = []
|
||||
|
||||
# get second step node
|
||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||
if not rag_pipeline_variables:
|
||||
return []
|
||||
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
|
||||
|
||||
# get datasource node data
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == start_node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if datasource_node_data:
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
for key, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables_map.pop(last_part)
|
||||
all_second_step_variables = list(variables_map.values())
|
||||
|
||||
for item in all_second_step_variables:
|
||||
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
|
||||
variables.append(RagPipelineVariableEntity.model_validate(item))
|
||||
|
||||
return variables
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
|
|
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
|
|||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -76,23 +77,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
@ -144,16 +151,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
|
|
@ -164,12 +182,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
|||
|
|
@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
|
|
@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
|||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
|
|
@ -395,9 +390,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
|
|
@ -442,32 +435,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
|
|
@ -759,8 +726,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
|
|
@ -808,8 +773,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
|
|
@ -822,17 +785,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# For unhandled events, we continue (original behavior)
|
||||
return
|
||||
|
||||
|
|
@ -856,11 +808,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
graph_runtime_state = event.graph_runtime_state
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueTextChunkEvent():
|
||||
yield from self._handle_text_chunk_event(
|
||||
event, tts_publisher=tts_publisher, queue_message=queue_message
|
||||
)
|
||||
|
||||
case QueueErrorEvent():
|
||||
yield from self._handle_error_event(event)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
|
|
|
|||
|
|
@ -126,6 +126,21 @@ class AppQueueManager:
|
|||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag without user permission check.
|
||||
This method allows stopping workflows without user context.
|
||||
|
||||
:param task_id: The task ID to stop
|
||||
:return:
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
|
|
|||
|
|
@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner):
|
|||
config=app_config.dataset,
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
show_retrieve_source=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
hit_callback=hit_callback,
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
|
|
@ -36,18 +31,17 @@ from core.app.entities.task_entities import (
|
|||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
|
|
@ -174,23 +168,25 @@ class WorkflowResponseConverter:
|
|||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.icon
|
||||
|
||||
return response
|
||||
|
||||
def workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
|
|
@ -227,9 +223,6 @@ class WorkflowResponseConverter:
|
|||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
|
|
@ -284,50 +277,6 @@ class WorkflowResponseConverter:
|
|||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunStartedEvent,
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -343,14 +292,12 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=new_inputs,
|
||||
inputs_truncated=truncated,
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -368,17 +315,10 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
# The `pre_iteration_output` field is not utilized by the frontend.
|
||||
# Previously, it was assigned the value of `event.output`.
|
||||
pre_iteration_output={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -402,7 +342,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
created_at=int(time.time()),
|
||||
|
|
@ -418,8 +358,6 @@ class WorkflowResponseConverter:
|
|||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -434,7 +372,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=new_inputs,
|
||||
|
|
@ -459,7 +397,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
# The `pre_loop_output` field is not utilized by the frontend.
|
||||
# Previously, it was assigned the value of `event.output`.
|
||||
|
|
@ -469,7 +407,6 @@ class WorkflowResponseConverter:
|
|||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -492,7 +429,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
created_at=int(time.time()),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return dict(blocking_response.to_dict())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class PipelineConfig(WorkflowUIBasedAppConfig):
|
||||
"""
|
||||
Pipeline Config Entity.
|
||||
"""
|
||||
|
||||
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
|
||||
pass
|
||||
|
||||
|
||||
class PipelineConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
|
||||
pipeline_config = PipelineConfig(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
workflow_id=workflow.id,
|
||||
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
|
||||
workflow=workflow, start_node_id=start_node_id
|
||||
),
|
||||
)
|
||||
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
:param only_structure_validate: only validate the structure of the config
|
||||
"""
|
||||
related_config_keys = []
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# text_to_speech
|
||||
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
|
|
@ -0,0 +1,802 @@
|
|||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
)
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# Add null check for dataset
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
|
||||
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
|
||||
)
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
|
||||
)
|
||||
documents = []
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
for datasource_info in datasource_info_list:
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document = self._build_document(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
created_from="rag-pipeline",
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
db.session.commit()
|
||||
|
||||
# run in child thread
|
||||
for i, datasource_info in enumerate(datasource_info_list):
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = None
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
document_id = documents[i].id
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document_id,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=json.dumps(datasource_info),
|
||||
datasource_node_id=start_node_id,
|
||||
input_data=inputs,
|
||||
pipeline_id=pipeline.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
db.session.commit()
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
dataset_id=dataset.id,
|
||||
start_node_id=start_node_id,
|
||||
batch=batch,
|
||||
document_id=document_id,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=pipeline_config.rag_pipeline_variables,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||
)
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context=contextvars.copy_context(),
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
else:
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
pipeline_id=pipeline.id,
|
||||
user_id=user.id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
workflow_id=workflow.id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
application_generate_entity=application_generate_entity.model_dump(),
|
||||
)
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
"dataset": PipelineDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
chunk_structure=dataset.chunk_structure,
|
||||
).model_dump(),
|
||||
"documents": [
|
||||
PipelineDocument(
|
||||
id=document.id,
|
||||
position=document.position,
|
||||
data_source_type=document.data_source_type,
|
||||
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
|
||||
name=document.name,
|
||||
indexing_status=document.indexing_status,
|
||||
error=document.error,
|
||||
enabled=document.enabled,
|
||||
).model_dump()
|
||||
for document in documents
|
||||
],
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
pipeline: Pipeline,
|
||||
workflow_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param pipeline: Pipeline
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
# init queue manager
|
||||
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
queue_manager = PipelineQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"queue_manager": queue_manager,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(
|
||||
invoke_from,
|
||||
)
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=streaming,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
|
||||
)
|
||||
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# init application generate entity - use RagPipelineGenerateEntity instead
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
dataset_id=dataset.id,
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
dataset_id=dataset.id,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
Workflow.app_id == application_generate_entity.app_config.app_id,
|
||||
Workflow.id == application_generate_entity.app_config.workflow_id,
|
||||
)
|
||||
)
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.SERVICE_API,
|
||||
}
|
||||
|
||||
if is_external_api_call:
|
||||
# For external API calls, use end user's session ID
|
||||
end_user = session.scalar(
|
||||
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
|
||||
)
|
||||
system_user_id = end_user.session_id if end_user else ""
|
||||
else:
|
||||
# For internal calls, use the original user ID
|
||||
system_user_id = application_generate_entity.user_id
|
||||
# workflow app
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_response(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(
|
||||
"Fails to process generate task pipeline, task_id: %r",
|
||||
application_generate_entity.task_id,
|
||||
)
|
||||
raise e
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Union[Account, EndUser],
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info["name"]
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info["page"]["page_name"]
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info["title"]
|
||||
elif datasource_type == "online_drive":
|
||||
name = datasource_info["key"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=position,
|
||||
data_source_type=datasource_type,
|
||||
data_source_info=json.dumps(datasource_info),
|
||||
batch=batch,
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
doc_form=document_form,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if built_in_field_enabled:
|
||||
doc_metadata = {
|
||||
BuiltInField.document_name: name,
|
||||
BuiltInField.uploader: account.name,
|
||||
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.source: datasource_type,
|
||||
}
|
||||
if doc_metadata:
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
def _format_datasource_info_list(
|
||||
self,
|
||||
datasource_type: str,
|
||||
datasource_info_list: list[Mapping[str, Any]],
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
start_node_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
) -> list[Mapping[str, Any]]:
|
||||
"""
|
||||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
all_files = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == start_node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if not datasource_node_data:
|
||||
raise ValueError("Datasource node data not found")
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
|
||||
datasource_name=datasource_node_data.get("datasource_name"),
|
||||
tenant_id=pipeline.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id"),
|
||||
credential_id=datasource_node_data.get("credential_id"),
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
|
||||
for datasource_info in datasource_info_list:
|
||||
if datasource_info.get("id") and datasource_info.get("type") == "folder":
|
||||
# get all files in the folder
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime,
|
||||
datasource_info.get("id", ""),
|
||||
datasource_info.get("bucket", None),
|
||||
user.id,
|
||||
all_files,
|
||||
datasource_info,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
all_files.append(
|
||||
{
|
||||
"id": datasource_info.get("id", ""),
|
||||
"bucket": datasource_info.get("bucket", None),
|
||||
}
|
||||
)
|
||||
return all_files
|
||||
else:
|
||||
return datasource_info_list
|
||||
|
||||
def _get_files_in_folder(
|
||||
self,
|
||||
datasource_runtime: OnlineDriveDatasourcePlugin,
|
||||
prefix: str,
|
||||
bucket: Optional[str],
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
"""
|
||||
result_generator = datasource_runtime.online_drive_browse_files(
|
||||
user_id=user_id,
|
||||
request=OnlineDriveBrowseFilesRequest(
|
||||
bucket=bucket,
|
||||
prefix=prefix,
|
||||
max_keys=20,
|
||||
next_page_parameters=next_page_parameters,
|
||||
),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
is_truncated = False
|
||||
last_file_key = None
|
||||
for result in result_generator:
|
||||
for files in result.result:
|
||||
for file in files.files:
|
||||
if file.type == "folder":
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime,
|
||||
file.id,
|
||||
bucket,
|
||||
user_id,
|
||||
all_files,
|
||||
datasource_info,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
all_files.append(
|
||||
{
|
||||
"id": file.id,
|
||||
"bucket": bucket,
|
||||
}
|
||||
)
|
||||
is_truncated = files.is_truncated
|
||||
next_page_parameters = files.next_page_parameters
|
||||
|
||||
if is_truncated:
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
|
||||
)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class PipelineQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(
|
||||
event,
|
||||
QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedError()
|
||||
|
|
@ -0,0 +1,280 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.enums import UserFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
Pipeline Application Runner
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(PipelineConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = SystemVariable(
|
||||
files=files,
|
||||
user_id=user_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
document_id=self.application_generate_entity.document_id,
|
||||
batch=self.application_generate_entity.batch,
|
||||
dataset_id=self.application_generate_entity.dataset_id,
|
||||
datasource_type=self.application_generate_entity.datasource_type,
|
||||
datasource_info=self.application_generate_entity.datasource_info,
|
||||
invoke_from=self.application_generate_entity.invoke_from.value,
|
||||
)
|
||||
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||
if (
|
||||
rag_pipeline_variable.belong_to_node_id
|
||||
in (self.application_generate_entity.start_node_id, "shared")
|
||||
) and rag_pipeline_variable.variable in inputs:
|
||||
rag_pipeline_variables.append(
|
||||
RAGPipelineVariableInput(
|
||||
variable=rag_pipeline_variable,
|
||||
value=inputs[rag_pipeline_variable.variable],
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._update_document_status(
|
||||
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
|
||||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(
|
||||
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
|
||||
) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
graph_config = workflow.graph_dict
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# nodes = graph_config.get("nodes", [])
|
||||
# edges = graph_config.get("edges", [])
|
||||
# real_run_nodes = []
|
||||
# real_edges = []
|
||||
# exclude_node_ids = []
|
||||
# for node in nodes:
|
||||
# node_id = node.get("id")
|
||||
# node_type = node.get("data", {}).get("type", "")
|
||||
# if node_type == "datasource":
|
||||
# if start_node_id != node_id:
|
||||
# exclude_node_ids.append(node_id)
|
||||
# continue
|
||||
# real_run_nodes.append(node)
|
||||
|
||||
# for edge in edges:
|
||||
# if edge.get("source") in exclude_node_ids:
|
||||
# continue
|
||||
# real_edges.append(edge)
|
||||
# graph_config = dict(graph_config)
|
||||
# graph_config["nodes"] = real_run_nodes
|
||||
# graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
|
||||
"""
|
||||
Update document status
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
|
|
@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
|
|
@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
|
|
@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
|
|
@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
|
|
@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
|
@ -432,17 +424,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
|
|
@ -472,7 +454,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from typing import Optional, cast
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
|
|
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
|
|||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
) -> None:
|
||||
|
|
@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
|
|
@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
@ -92,15 +97,26 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
|
|
@ -111,11 +127,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
|
|
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
|
|
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
|
|||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
|
@ -350,9 +346,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
|
|
@ -371,32 +365,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
|
|
@ -618,8 +586,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
|
|
@ -634,7 +600,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
|
||||
def _dispatch_event(
|
||||
self,
|
||||
event: Any,
|
||||
event: AppQueueEvent,
|
||||
*,
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
|
|
@ -661,8 +627,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
|
|
@ -675,17 +639,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# Handle workflow failed and stop events with isinstance check
|
||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||
yield from self._handle_workflow_failed_and_stop_events(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
|
|
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
|
|
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeInLoopFailedEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
|
@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
|
|||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
|
|
@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
|
|
@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
|
|
@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -310,29 +366,21 @@ class WorkflowBasedAppRunner:
|
|||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
|
|
@ -343,6 +391,8 @@ class WorkflowBasedAppRunner:
|
|||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
|
|
@ -350,44 +400,30 @@ class WorkflowBasedAppRunner:
|
|||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
else:
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
execution_metadata = {}
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
start_at=event.start_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
|
|
@ -396,34 +432,18 @@ class WorkflowBasedAppRunner:
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
|
@ -434,93 +454,21 @@ class WorkflowBasedAppRunner:
|
|||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInLoopFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInLoopFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
|
@ -533,10 +481,10 @@ class WorkflowBasedAppRunner:
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
|
|
@ -547,51 +495,13 @@ class WorkflowBasedAppRunner:
|
|||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
|
|
@ -599,55 +509,41 @@ class WorkflowBasedAppRunner:
|
|||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
|
|
@ -655,42 +551,32 @@ class WorkflowBasedAppRunner:
|
|||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
|
@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
|
||||
class InvokeFrom(Enum):
|
||||
class InvokeFrom(StrEnum):
|
||||
"""
|
||||
Invoke From.
|
||||
"""
|
||||
|
|
@ -35,6 +35,7 @@ class InvokeFrom(Enum):
|
|||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
|
|
@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
||||
|
||||
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
RAG Pipeline Application Generate Entity.
|
||||
"""
|
||||
|
||||
# pipeline config
|
||||
pipeline_config: WorkflowUIBasedAppConfig
|
||||
datasource_type: str
|
||||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: Optional[str] = None
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
class SingleLoopRunEntity(BaseModel):
|
||||
"""
|
||||
Single Loop Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@ from pydantic import BaseModel
|
|||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
|
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
|
|||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
|
|
@ -80,15 +75,7 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
|
|
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
node_title: str
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
|
|
@ -134,15 +110,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
|
|
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -204,7 +172,6 @@ class QueueLoopNextEvent(AppQueueEvent):
|
|||
"""iteration run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
|
|
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteration run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
|
|
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -417,10 +380,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
"""single loop duration map"""
|
||||
loop_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
|
|
@ -454,72 +413,6 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
|||
retry_index: int # retry index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInLoopFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
|
|
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -563,15 +455,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
|
|
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
|
|||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
error: str
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
|
@ -71,8 +71,6 @@ class StreamEvent(Enum):
|
|||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
|
|
@ -447,54 +445,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
|
|
@ -514,8 +464,6 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
|||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
@ -538,12 +486,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
@ -577,8 +520,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
execution_metadata: Optional[Mapping] = None
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
|
|
@ -633,7 +574,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
|||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
|
|||
|
|
@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||
|
||||
self.current_loop += 1
|
||||
|
||||
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
|
||||
"""Run on datasource start."""
|
||||
if dify_config.DEBUG:
|
||||
print_text(
|
||||
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
|
||||
color=self.color,
|
||||
)
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
|
||||
|
||||
class DatasourcePlugin(ABC):
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
@abstractmethod
|
||||
def datasource_provider_type(self) -> str:
|
||||
"""
|
||||
returns the type of the datasource provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class DatasourcePluginProviderController(ABC):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
returns whether the provider needs credentials
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@abstractmethod
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||
|
||||
|
||||
class DatasourceRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a datasource call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
datasource_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeDatasourceRuntime(DatasourceRuntime):
|
||||
"""
|
||||
Fake datasource runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
datasource_id="fake_datasource_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceFileManager:
|
||||
@staticmethod
|
||||
def sign_file(datasource_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: Optional[str] = None,
|
||||
) -> UploadFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
unique_filename = f"{unique_name}{extension}"
|
||||
# default just as before
|
||||
present_filename = unique_filename
|
||||
if filename is not None:
|
||||
has_extension = len(filename.split(".")) > 1
|
||||
# Add extension flexibly
|
||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||
filepath = f"datasources/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=filepath,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_binary).hexdigest(),
|
||||
source_url="",
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> UploadFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = (
|
||||
guess_type(file_url)[0]
|
||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
or "application/octet-stream"
|
||||
)
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=filepath,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(blob).hexdigest(),
|
||||
source_url=file_url,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(upload_file.key)
|
||||
|
||||
return blob, upload_file.mime_type
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_upload_file_id(upload_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == upload_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not upload_file:
|
||||
return None, None
|
||||
|
||||
stream = storage.load_stream(upload_file.key)
|
||||
|
||||
return stream, upload_file.mime_type
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
# from core.file.datasource_file_parser import datasource_file_manager
|
||||
#
|
||||
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
import logging
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
) -> DatasourcePluginProviderController:
|
||||
"""
|
||||
get the datasource plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.datasource_plugin_providers.get()
|
||||
except LookupError:
|
||||
contexts.datasource_plugin_providers.set({})
|
||||
contexts.datasource_plugin_providers_lock.set(Lock())
|
||||
|
||||
with contexts.datasource_plugin_providers_lock.get():
|
||||
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||
if provider_id in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider_id]
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
datasource_plugin_providers[provider_id] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_datasource_runtime(
|
||||
cls,
|
||||
provider_id: str,
|
||||
datasource_name: str,
|
||||
tenant_id: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> DatasourcePlugin:
|
||||
"""
|
||||
get the datasource runtime
|
||||
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:param datasource_name: the name of the datasource
|
||||
:param tenant_id: the tenant id
|
||||
|
||||
:return: the datasource plugin
|
||||
"""
|
||||
return cls.get_datasource_plugin_provider(
|
||||
provider_id,
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[DatasourceParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: str
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("datasources", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: Optional[str] = Field(default=None)
|
||||
pt_BR: Optional[str] = Field(default=None)
|
||||
ja_JP: Optional[str] = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
|
||||
|
||||
|
||||
class DatasourceProviderType(enum.StrEnum):
|
||||
"""
|
||||
Enum class for datasource provider
|
||||
"""
|
||||
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
LOCAL_FILE = "local_file"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
ONLINE_DRIVE = "online_drive"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class DatasourceParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class DatasourceParameterType(enum.StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||
description: I18nObject = Field(..., description="The description of the parameter")
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "DatasourceParameter":
|
||||
"""
|
||||
get a simple datasource parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param typ: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=typ,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class DatasourceIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the datasource")
|
||||
name: str = Field(..., description="The name of the datasource")
|
||||
label: I18nObject = Field(..., description="The label of the datasource")
|
||||
provider: str = Field(..., description="The provider of the datasource")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
identity: DatasourceIdentity
|
||||
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The label of the datasource")
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class DatasourceProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class DatasourceProviderEntity(BaseModel):
|
||||
"""
|
||||
Datasource provider entity
|
||||
"""
|
||||
|
||||
identity: DatasourceProviderIdentity
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
provider_type: DatasourceProviderType
|
||||
|
||||
|
||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatasourceInvokeMeta(BaseModel):
|
||||
"""
|
||||
Datasource invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of DatasourceInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an instance of DatasourceInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class DatasourceLabel(BaseModel):
|
||||
"""
|
||||
Datasource label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class OnlineDocumentPage(BaseModel):
|
||||
"""
|
||||
Online document page
|
||||
"""
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_name: str = Field(..., description="The page title")
|
||||
page_icon: Optional[dict] = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
parent_id: Optional[str] = Field(None, description="The parent page id")
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info
|
||||
"""
|
||||
|
||||
workspace_id: Optional[str] = Field(None, description="The workspace id")
|
||||
workspace_name: Optional[str] = Field(None, description="The workspace name")
|
||||
workspace_icon: Optional[str] = Field(None, description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
|
||||
class OnlineDocumentPagesMessage(BaseModel):
|
||||
"""
|
||||
Get online document pages response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentInfo]
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||
"""
|
||||
Get online document page content request
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
|
||||
|
||||
class OnlineDocumentPageContent(BaseModel):
|
||||
"""
|
||||
Online document page content
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
content: str = Field(..., description="The content of the page")
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||
"""
|
||||
Get online document page content response
|
||||
"""
|
||||
|
||||
result: OnlineDocumentPageContent
|
||||
|
||||
|
||||
class GetWebsiteCrawlRequest(BaseModel):
|
||||
"""
|
||||
Get website crawl request
|
||||
"""
|
||||
|
||||
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||
|
||||
|
||||
class WebSiteInfoDetail(BaseModel):
|
||||
source_url: str = Field(..., description="The url of the website")
|
||||
content: str = Field(..., description="The content of the website")
|
||||
title: str = Field(..., description="The title of the website")
|
||||
description: str = Field(..., description="The description of the website")
|
||||
|
||||
|
||||
class WebSiteInfo(BaseModel):
|
||||
"""
|
||||
Website info
|
||||
"""
|
||||
|
||||
status: Optional[str] = Field(..., description="crawl job status")
|
||||
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||
total: Optional[int] = Field(default=0, description="The total number of websites")
|
||||
completed: Optional[int] = Field(default=0, description="The number of completed websites")
|
||||
|
||||
|
||||
class WebsiteCrawlMessage(BaseModel):
|
||||
"""
|
||||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
|
||||
|
||||
|
||||
class DatasourceMessage(ToolInvokeMessage):
|
||||
pass
|
||||
|
||||
|
||||
#########################
|
||||
# Online drive file
|
||||
#########################
|
||||
|
||||
|
||||
class OnlineDriveFile(BaseModel):
|
||||
"""
|
||||
Online drive file
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The file ID")
|
||||
name: str = Field(..., description="The file name")
|
||||
size: int = Field(..., description="The file size")
|
||||
type: str = Field(..., description="The file type: folder or file")
|
||||
|
||||
|
||||
class OnlineDriveFileBucket(BaseModel):
|
||||
"""
|
||||
Online drive file bucket
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The file bucket")
|
||||
files: list[OnlineDriveFile] = Field(..., description="The file list")
|
||||
is_truncated: bool = Field(False, description="Whether the result is truncated")
|
||||
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
"""
|
||||
Get online drive file list request
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The file bucket")
|
||||
prefix: str = Field(..., description="The parent folder ID")
|
||||
max_keys: int = Field(20, description="Page size for pagination")
|
||||
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||
"""
|
||||
Get online drive file list response
|
||||
"""
|
||||
|
||||
result: list[OnlineDriveFileBucket] = Field(..., description="The list of file buckets")
|
||||
|
||||
|
||||
class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
Get online drive file
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the file")
|
||||
bucket: Optional[str] = Field(None, description="The name of the bucket")
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||
|
||||
|
||||
class DatasourceProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceEngineInvokeError(Exception):
|
||||
meta: DatasourceInvokeMeta
|
||||
|
||||
def __init__(self, meta, **kwargs):
|
||||
self.meta = meta
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
|
||||
|
||||
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return LocalFileDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_pages(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_page_content(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDocumentDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_browse_files(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_download_file(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDriveDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_datasource_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Transform datasource message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, DatasourceMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
file = DatasourceFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
# get filename from meta
|
||||
filename = meta.get("file_name", None)
|
||||
|
||||
mimetype = meta.get("mime_type")
|
||||
if not mimetype:
|
||||
mimetype = guess_type(filename)[0] or "application/octet-stream"
|
||||
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
|
||||
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
|
||||
|
|
@ -0,0 +1,389 @@
|
|||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_website_crawl(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceProviderEntityWithPlugin,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return WebsiteCrawlDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
|
@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel):
|
|||
total_segments: int
|
||||
preview: list[PreviewDetail]
|
||||
qa_preview: Optional[list[QAPreviewDetail]] = None
|
||||
|
||||
|
||||
class PipelineDataset(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
chunk_structure: str
|
||||
|
||||
|
||||
class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: Optional[dict] = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: Optional[str] = None
|
||||
enabled: bool
|
||||
|
||||
|
||||
class PipelineGenerateResponse(BaseModel):
|
||||
batch: str
|
||||
dataset: PipelineDataset
|
||||
documents: list[PipelineDocument]
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
|
|||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import (
|
||||
|
|
@ -41,6 +40,7 @@ from models.provider import (
|
|||
ProviderType,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -627,6 +627,7 @@ class ProviderConfiguration(BaseModel):
|
|||
Get custom model credentials.
|
||||
"""
|
||||
# get provider model
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
@ -1124,6 +1125,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
@ -1207,6 +1209,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
@ -1340,7 +1343,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.datasource import datasource_file_manager
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
|
||||
|
||||
class DatasourceFileParser:
|
||||
@staticmethod
|
||||
def get_datasource_file_manager() -> "DatasourceFileManager":
|
||||
return cast("DatasourceFileManager", datasource_file_manager["manager"])
|
||||
|
|
@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
|
|||
REMOTE_URL = "remote_url"
|
||||
LOCAL_FILE = "local_file"
|
||||
TOOL_FILE = "tool_file"
|
||||
DATASOURCE_FILE = "datasource_file"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
|
|
|
|||
|
|
@ -97,7 +97,11 @@ def to_prompt_message_content(
|
|||
|
||||
|
||||
def download(f: File, /):
|
||||
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
||||
if f.transfer_method in (
|
||||
FileTransferMethod.TOOL_FILE,
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
FileTransferMethod.DATASOURCE_FILE,
|
||||
):
|
||||
return _download_file_content(f._storage_key)
|
||||
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||
|
|
|
|||
|
|
@ -115,11 +115,10 @@ class File(BaseModel):
|
|||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE or self.transfer_method == FileTransferMethod.DATASOURCE_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ def obfuscated_token(token: str):
|
|||
|
||||
|
||||
def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_provider_name(
|
||||
providers: Sequence[Any], credential_type: CredentialType, fallback_context: str = "provider"
|
||||
) -> str:
|
||||
try:
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in providers],
|
||||
f"{credential_type.get_name()}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Error generating next provider name for %r: %r", fallback_context, e)
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
|
||||
def generate_incremental_name(
|
||||
names: Sequence[str],
|
||||
default_pattern: str,
|
||||
) -> str:
|
||||
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
continue
|
||||
match = re.match(pattern, name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
if not numbers:
|
||||
return f"{default_pattern} 1"
|
||||
|
||||
max_number = max(numbers)
|
||||
return f"{default_pattern} {max_number + 1}"
|
||||
|
|
@ -359,6 +359,7 @@ class IndexingRunner:
|
|||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
|
|||
|
|
@ -28,9 +28,10 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
|||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel, db
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
|
|
@ -32,7 +33,12 @@ class TokenBufferMemory:
|
|||
self.model_instance = model_instance
|
||||
|
||||
def _build_prompt_message_with_files(
|
||||
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
|
||||
self,
|
||||
message_files: Sequence[MessageFile],
|
||||
text_content: str,
|
||||
message: Message,
|
||||
app_record,
|
||||
is_user_message: bool,
|
||||
) -> PromptMessage:
|
||||
"""
|
||||
Build prompt message with files.
|
||||
|
|
@ -98,80 +104,80 @@ class TokenBufferMemory:
|
|||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
app_record = self.conversation.app
|
||||
with Session(db.engine) as session:
|
||||
app_record = self.conversation.app
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
|
||||
)
|
||||
# fetch limited messages, and return reversed
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
if message_limit and message_limit > 0:
|
||||
message_limit = min(message_limit, 500)
|
||||
else:
|
||||
message_limit = 500
|
||||
if message_limit and message_limit > 0:
|
||||
message_limit = min(message_limit, 500)
|
||||
else:
|
||||
message_limit = 500
|
||||
|
||||
stmt = stmt.limit(message_limit)
|
||||
stmt = stmt.limit(message_limit)
|
||||
|
||||
messages = db.session.scalars(stmt).all()
|
||||
messages = session.scalars(stmt).all()
|
||||
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
messages = list(reversed(thread_messages))
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
# Process user message with files
|
||||
user_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
# Process user message with files
|
||||
user_file_query = select(MessageFile).where(
|
||||
MessageFile.message_id == message.id,
|
||||
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
user_files = session.scalars(user_file_query).all()
|
||||
|
||||
if user_files:
|
||||
user_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=user_files,
|
||||
text_content=message.query,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=True,
|
||||
if user_files:
|
||||
user_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=user_files,
|
||||
text_content=message.query,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=True,
|
||||
)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
# Process assistant message with files
|
||||
assistant_file_query = select(MessageFile).where(
|
||||
MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant"
|
||||
)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
assistant_files = session.scalars(assistant_file_query).all()
|
||||
|
||||
# Process assistant message with files
|
||||
assistant_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||
.all()
|
||||
)
|
||||
if assistant_files:
|
||||
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=assistant_files,
|
||||
text_content=message.answer,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=False,
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
if assistant_files:
|
||||
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||
message_files=assistant_files,
|
||||
text_content=message.answer,
|
||||
message=message,
|
||||
app_record=app_record,
|
||||
is_user_message=False,
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message)
|
||||
else:
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ from core.model_runtime.errors.invoke import (
|
|||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
|
|
@ -53,6 +52,8 @@ class AIModel(BaseModel):
|
|||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
|
|
@ -140,6 +141,8 @@ class AIModel(BaseModel):
|
|||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
# sort credentials
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import (
|
|||
PriceType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel):
|
|||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel):
|
|||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from pydantic import ConfigDict
|
|||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
|
|
@ -31,6 +30,8 @@ class ModerationModel(AIModel):
|
|||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue